// rlx_runtime/logit_verify.rs
/// Acceptance limits used by [`compare`] when checking an output against a reference.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Tolerance {
    /// Bound on the absolute element-wise error `|out - reference|`.
    pub max_abs: f32,
    /// Bound on the relative element-wise error, as a fraction of `|reference|`.
    pub max_rel: f32,
    /// Minimum cosine similarity required between the output and reference vectors.
    pub min_cosine: f32,
}
35
36impl Tolerance {
37 pub const STRICT: Self = Self {
39 max_abs: 1e-4,
40 max_rel: 1e-3,
41 min_cosine: 0.9999,
42 };
43 pub const LOOSE_F16: Self = Self {
45 max_abs: 5e-2,
46 max_rel: 5e-2,
47 min_cosine: 0.999,
48 };
49}
50
/// Error statistics produced by [`diff`] for one output/reference pair.
#[derive(Debug, Clone, PartialEq)]
pub struct Diff {
    /// Number of elements compared.
    pub n: usize,
    /// Largest absolute element-wise difference `|out[i] - reference[i]|`.
    pub max_abs: f32,
    /// Largest relative difference `|out[i] - reference[i]| / max(|reference[i]|, 1e-12)`.
    ///
    /// This grows without bound as reference values approach zero, so it is most useful as a
    /// diagnostic.
    pub max_rel: f32,
    /// Mean absolute element-wise difference (0 for empty input).
    pub mean_abs: f32,
    /// Cosine similarity between the output and reference vectors.
    pub cosine: f32,
    /// Index of the element with the largest absolute difference.
    pub argmax_diff_index: usize,
}
60
61#[derive(Debug)]
62pub enum VerifyError {
63 LengthMismatch { got: usize, expected: usize },
64 ToleranceExceeded { diff: Diff, tolerance: Tolerance },
65 Nonfinite,
66}
67
68impl std::fmt::Display for VerifyError {
69 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
70 match self {
71 Self::LengthMismatch { got, expected } => {
72 write!(f, "length mismatch: got {got}, expected {expected}")
73 }
74 Self::ToleranceExceeded { diff, tolerance } => write!(
75 f,
76 "tolerance exceeded: max_abs={:.3e} max_rel={:.3e} cos={:.5} \
77 (limits abs={:.3e}, rel={:.3e}, cos>={:.5})",
78 diff.max_abs,
79 diff.max_rel,
80 diff.cosine,
81 tolerance.max_abs,
82 tolerance.max_rel,
83 tolerance.min_cosine
84 ),
85 Self::Nonfinite => write!(f, "non-finite value (NaN or inf) in output"),
86 }
87 }
88}
89
90impl std::error::Error for VerifyError {}
91
92pub fn diff(out: &[f32], reference: &[f32]) -> Result<Diff, VerifyError> {
95 if out.len() != reference.len() {
96 return Err(VerifyError::LengthMismatch {
97 got: out.len(),
98 expected: reference.len(),
99 });
100 }
101 let mut max_abs = 0f32;
102 let mut max_rel = 0f32;
103 let mut sum_abs = 0f32;
104 let mut argmax = 0usize;
105 let mut dot = 0f32;
106 let mut na = 0f32;
107 let mut nb = 0f32;
108 for (i, (&a, &b)) in out.iter().zip(reference).enumerate() {
109 if !a.is_finite() || !b.is_finite() {
110 return Err(VerifyError::Nonfinite);
111 }
112 let d = (a - b).abs();
113 sum_abs += d;
114 if d > max_abs {
115 max_abs = d;
116 argmax = i;
117 }
118 let denom = b.abs().max(1e-12);
119 let rel = d / denom;
120 if rel > max_rel {
121 max_rel = rel;
122 }
123 dot += a * b;
124 na += a * a;
125 nb += b * b;
126 }
127 let n = out.len();
128 let cosine = if na > 0.0 && nb > 0.0 {
129 dot / (na.sqrt() * nb.sqrt())
130 } else {
131 1.0
132 };
133 Ok(Diff {
134 n,
135 max_abs,
136 max_rel,
137 mean_abs: if n > 0 { sum_abs / n as f32 } else { 0.0 },
138 cosine,
139 argmax_diff_index: argmax,
140 })
141}
142
143pub fn compare(out: &[f32], reference: &[f32], tol: Tolerance) -> Result<Diff, VerifyError> {
145 let d = diff(out, reference)?;
146 if d.max_abs > tol.max_abs || d.max_rel > tol.max_rel || d.cosine < tol.min_cosine {
147 return Err(VerifyError::ToleranceExceeded {
148 diff: d,
149 tolerance: tol,
150 });
151 }
152 Ok(d)
153}
154
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn identical_passes() {
        let a = [1.0f32, 2.0, 3.0];
        let r = compare(&a, &a, Tolerance::STRICT).unwrap();
        assert_eq!(r.max_abs, 0.0);
        assert!((r.cosine - 1.0).abs() < 1e-6);
    }

    #[test]
    fn tiny_diff_passes_strict() {
        let a = [1.0f32, 2.0, 3.0];
        let b = [1.0 + 1e-6, 2.0 - 1e-6, 3.0];
        compare(&a, &b, Tolerance::STRICT).unwrap();
    }

    #[test]
    fn big_diff_fails() {
        let a = [1.0f32, 2.0, 3.0];
        let b = [1.0, 2.0, 99.0];
        let err = compare(&a, &b, Tolerance::STRICT).unwrap_err();
        assert!(matches!(err, VerifyError::ToleranceExceeded { .. }));
    }

    #[test]
    fn nan_is_hard_fail() {
        let a = [1.0f32, f32::NAN];
        let b = [1.0, 0.0];
        let err = compare(&a, &b, Tolerance::STRICT).unwrap_err();
        assert!(matches!(err, VerifyError::Nonfinite));
    }

    #[test]
    fn nonfinite_reference_is_hard_fail() {
        let a = [1.0f32, 2.0];
        let b = [1.0, f32::INFINITY];
        let err = compare(&a, &b, Tolerance::LOOSE_F16).unwrap_err();
        assert!(matches!(err, VerifyError::Nonfinite));
    }

    #[test]
    fn length_mismatch_is_reported() {
        let a = [1.0f32, 2.0];
        let b = [1.0f32];
        let err = compare(&a, &b, Tolerance::STRICT).unwrap_err();
        assert!(matches!(
            err,
            VerifyError::LengthMismatch {
                got: 2,
                expected: 1
            }
        ));
    }

    #[test]
    fn empty_input_passes() {
        let empty: [f32; 0] = [];
        let r = compare(&empty, &empty, Tolerance::STRICT).unwrap();
        assert_eq!(r.n, 0);
        assert_eq!(r.mean_abs, 0.0);
    }

    #[test]
    fn diff_locates_largest_error() {
        let a = [1.0f32, 2.0, 3.0];
        let b = [1.0, 2.5, 3.25];
        let d = diff(&a, &b).unwrap();
        assert_eq!(d.argmax_diff_index, 1);
        assert_eq!(d.max_abs, 0.5);
    }
}