model_runtime/
conformance.rs1use crate::{ModelRuntimeError, Result};
2
3use crate::RawPrediction;
4
/// Tolerances used by [`compare_blue_green_predictions`] when checking that the "green"
/// and "blue" deployments of a model produce equivalent predictions.
///
/// The [`Default`] value uses a score tolerance of `1.0e-4` and sets `compare_regions`
/// to `true`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct BlueGreenPredictionTestOptions {
    /// Largest absolute difference allowed between a green score and the blue score at
    /// the same position. Should be a non-negative number.
    pub max_score_delta: f32,
    /// Whether prediction regions should be compared as well.
    ///
    /// NOTE(review): `compare_blue_green_predictions` currently reads this flag without
    /// acting on it, so regions are never compared. Confirm before relying on it.
    pub compare_regions: bool,
}
13
14impl Default for BlueGreenPredictionTestOptions {
15 fn default() -> Self {
16 Self {
17 max_score_delta: 1.0e-4,
18 compare_regions: true,
19 }
20 }
21}
22
/// Summary returned by [`compare_blue_green_predictions`] when every green/blue
/// prediction pair matched.
#[derive(Debug, Clone, PartialEq)]
pub struct BlueGreenPredictionTestReport {
    /// Number of green/blue prediction pairs that were compared (the length of either
    /// input slice, which must be equal).
    pub compared_predictions: usize,
    /// Largest absolute score difference observed across pairs where both sides carried
    /// a score; `0.0` when no pair had scores on both sides.
    pub max_score_delta: f32,
}
31
32pub fn compare_blue_green_predictions(
34 green: &[RawPrediction],
35 blue: &[RawPrediction],
36 options: BlueGreenPredictionTestOptions,
37) -> Result<BlueGreenPredictionTestReport> {
38 if green.len() != blue.len() {
39 return Err(ModelRuntimeError::InvalidArgument(format!(
40 "blue/green prediction counts differ: green={}, blue={}",
41 green.len(),
42 blue.len()
43 )));
44 }
45
46 let mut max_score_delta = 0.0_f32;
47 for (index, (green, blue)) in green.iter().zip(blue).enumerate() {
48 if green.kind != blue.kind {
49 return Err(blue_green_mismatch(index, "kind", &green.kind, &blue.kind));
50 }
51 if green.label != blue.label {
52 return Err(blue_green_mismatch(
53 index,
54 "label",
55 &green.label,
56 &blue.label,
57 ));
58 }
59 if green.text != blue.text {
60 return Err(blue_green_mismatch(index, "text", &green.text, &blue.text));
61 }
62 match (green.score, blue.score) {
63 (Some(green_score), Some(blue_score)) => {
64 if !green_score.is_finite() || !blue_score.is_finite() {
65 return Err(ModelRuntimeError::InvalidArgument(format!(
66 "blue/green prediction {index} has a non-finite score"
67 )));
68 }
69 let delta = (green_score - blue_score).abs();
70 max_score_delta = max_score_delta.max(delta);
71 if delta > options.max_score_delta {
72 return Err(ModelRuntimeError::InvalidArgument(format!(
73 "blue/green prediction {index} score delta {delta} exceeds {}",
74 options.max_score_delta
75 )));
76 }
77 }
78 (None, None) => {}
79 _ => {
80 return Err(blue_green_mismatch(
81 index,
82 "score",
83 &green.score,
84 &blue.score,
85 ))
86 }
87 }
88 let _ = options.compare_regions;
89 }
90
91 Ok(BlueGreenPredictionTestReport {
92 compared_predictions: green.len(),
93 max_score_delta,
94 })
95}
96
97fn blue_green_mismatch<T: std::fmt::Debug>(
98 index: usize,
99 field: &str,
100 green: &T,
101 blue: &T,
102) -> ModelRuntimeError {
103 ModelRuntimeError::InvalidArgument(format!(
104 "blue/green prediction {index} {field} mismatch: green={green:?}, blue={blue:?}"
105 ))
106}