// model_runtime/conformance.rs

1use crate::{ModelRuntimeError, Result};
2
3use crate::RawPrediction;
4
/// Options controlling a blue/green prediction conformance comparison.
///
/// The [`Default`] value uses a score tolerance of `1.0e-4` and enables
/// region comparison.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct BlueGreenPredictionTestOptions {
    /// Maximum allowed absolute score delta.
    pub max_score_delta: f32,
    /// Whether bounding boxes are compared.
    pub compare_regions: bool,
}

impl Default for BlueGreenPredictionTestOptions {
    /// Returns the default options: a `1.0e-4` score tolerance with region
    /// comparison enabled.
    fn default() -> Self {
        let max_score_delta = 1.0e-4;
        let compare_regions = true;
        Self {
            max_score_delta,
            compare_regions,
        }
    }
}
22
/// Summary of a successful blue/green prediction comparison.
#[derive(Debug, Clone, PartialEq)]
pub struct BlueGreenPredictionTestReport {
    /// Number of predictions compared.
    pub compared_predictions: usize,
    /// Maximum observed absolute score delta.
    pub max_score_delta: f32,
}
31
32/// Compares green and blue model predictions for runtime conformance tests.
33pub fn compare_blue_green_predictions(
34    green: &[RawPrediction],
35    blue: &[RawPrediction],
36    options: BlueGreenPredictionTestOptions,
37) -> Result<BlueGreenPredictionTestReport> {
38    if green.len() != blue.len() {
39        return Err(ModelRuntimeError::InvalidArgument(format!(
40            "blue/green prediction counts differ: green={}, blue={}",
41            green.len(),
42            blue.len()
43        )));
44    }
45
46    let mut max_score_delta = 0.0_f32;
47    for (index, (green, blue)) in green.iter().zip(blue).enumerate() {
48        if green.kind != blue.kind {
49            return Err(blue_green_mismatch(index, "kind", &green.kind, &blue.kind));
50        }
51        if green.label != blue.label {
52            return Err(blue_green_mismatch(
53                index,
54                "label",
55                &green.label,
56                &blue.label,
57            ));
58        }
59        if green.text != blue.text {
60            return Err(blue_green_mismatch(index, "text", &green.text, &blue.text));
61        }
62        match (green.score, blue.score) {
63            (Some(green_score), Some(blue_score)) => {
64                if !green_score.is_finite() || !blue_score.is_finite() {
65                    return Err(ModelRuntimeError::InvalidArgument(format!(
66                        "blue/green prediction {index} has a non-finite score"
67                    )));
68                }
69                let delta = (green_score - blue_score).abs();
70                max_score_delta = max_score_delta.max(delta);
71                if delta > options.max_score_delta {
72                    return Err(ModelRuntimeError::InvalidArgument(format!(
73                        "blue/green prediction {index} score delta {delta} exceeds {}",
74                        options.max_score_delta
75                    )));
76                }
77            }
78            (None, None) => {}
79            _ => {
80                return Err(blue_green_mismatch(
81                    index,
82                    "score",
83                    &green.score,
84                    &blue.score,
85                ))
86            }
87        }
88        let _ = options.compare_regions;
89    }
90
91    Ok(BlueGreenPredictionTestReport {
92        compared_predictions: green.len(),
93        max_score_delta,
94    })
95}
96
97fn blue_green_mismatch<T: std::fmt::Debug>(
98    index: usize,
99    field: &str,
100    green: &T,
101    blue: &T,
102) -> ModelRuntimeError {
103    ModelRuntimeError::InvalidArgument(format!(
104        "blue/green prediction {index} {field} mismatch: green={green:?}, blue={blue:?}"
105    ))
106}