Skip to main content

rlx_runtime/
logit_verify.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Logit / output verification (plan #61).
17//!
//! Borrowed from MAX's `tests/integration/accuracy/verify` pattern:
//! every model gets a parity test that diffs RLX's output against a
//! reference (HuggingFace transformers, ONNX Runtime, hand-fused, ...)
//! using cosine similarity, KL divergence, and absolute tolerance.
//! The functions in this file cover cosine similarity plus absolute
//! and relative error; they do not compute KL divergence.
23//!
24//! Pure data layer — no HF / ORT integration here. Test code calls
25//! `compare(out, reference, tolerance)` and gets back a structured
26//! report it can `assert!` against. Hooking this up to specific
27//! reference implementations is per-bench wiring (see `burnembed`).
28
/// Acceptance thresholds that [`compare`] applies to a computed [`Diff`].
///
/// A comparison is rejected as soon as any one of the three limits is
/// violated.
#[derive(Debug, Clone, Copy)]
pub struct Tolerance {
    /// Upper bound on `Diff::max_abs`, the largest element-wise absolute difference.
    pub max_abs: f32,
    /// Upper bound on `Diff::max_rel`, the largest element-wise relative error.
    pub max_rel: f32,
    /// Lower bound on `Diff::cosine`, the cosine similarity of the two vectors.
    pub min_cosine: f32,
}
35
36impl Tolerance {
37    /// Strict — for f32-vs-f32 comparisons.
38    pub const STRICT: Self = Self {
39        max_abs: 1e-4,
40        max_rel: 1e-3,
41        min_cosine: 0.9999,
42    };
43    /// Loose — for f16 / bf16 against f32 reference.
44    pub const LOOSE_F16: Self = Self {
45        max_abs: 5e-2,
46        max_rel: 5e-2,
47        min_cosine: 0.999,
48    };
49}
50
/// Summary statistics produced by [`diff`] for one `out` / `reference` pair.
#[derive(Debug, Clone)]
pub struct Diff {
    /// Number of elements compared.
    pub n: usize,
    /// Largest element-wise absolute difference `|out[i] - reference[i]|`.
    pub max_abs: f32,
    /// Largest element-wise relative error, measured against `|reference[i]|`.
    pub max_rel: f32,
    /// Mean of the element-wise absolute differences (`0.0` when `n == 0`).
    pub mean_abs: f32,
    /// Cosine similarity between `out` and `reference`.
    pub cosine: f32,
    /// Index at which `max_abs` was first reached (`0` if no element differs).
    pub argmax_diff_index: usize,
}
60
/// Reasons a comparison performed by [`diff`] or [`compare`] can fail.
#[derive(Debug)]
pub enum VerifyError {
    /// The two slices have different lengths; `got` is the length of the
    /// output and `expected` the length of the reference.
    LengthMismatch { got: usize, expected: usize },
    /// At least one statistic in `diff` violates the limits in `tolerance`.
    ToleranceExceeded { diff: Diff, tolerance: Tolerance },
    /// A NaN or infinite value was found in the output or the reference.
    Nonfinite,
}
67
68impl std::fmt::Display for VerifyError {
69    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
70        match self {
71            Self::LengthMismatch { got, expected } => {
72                write!(f, "length mismatch: got {got}, expected {expected}")
73            }
74            Self::ToleranceExceeded { diff, tolerance } => write!(
75                f,
76                "tolerance exceeded: max_abs={:.3e} max_rel={:.3e} cos={:.5} \
77                    (limits abs={:.3e}, rel={:.3e}, cos>={:.5})",
78                diff.max_abs,
79                diff.max_rel,
80                diff.cosine,
81                tolerance.max_abs,
82                tolerance.max_rel,
83                tolerance.min_cosine
84            ),
85            Self::Nonfinite => write!(f, "non-finite value (NaN or inf) in output"),
86        }
87    }
88}
89
// Marker impl: the default `std::error::Error` methods are sufficient, and
// this lets `VerifyError` be boxed as `Box<dyn std::error::Error>`.
impl std::error::Error for VerifyError {}
91
92/// Compute the per-pair diff between `out` and `ref`. NaN/inf in
93/// either side is a hard fail.
94pub fn diff(out: &[f32], reference: &[f32]) -> Result<Diff, VerifyError> {
95    if out.len() != reference.len() {
96        return Err(VerifyError::LengthMismatch {
97            got: out.len(),
98            expected: reference.len(),
99        });
100    }
101    let mut max_abs = 0f32;
102    let mut max_rel = 0f32;
103    let mut sum_abs = 0f32;
104    let mut argmax = 0usize;
105    let mut dot = 0f32;
106    let mut na = 0f32;
107    let mut nb = 0f32;
108    for (i, (&a, &b)) in out.iter().zip(reference).enumerate() {
109        if !a.is_finite() || !b.is_finite() {
110            return Err(VerifyError::Nonfinite);
111        }
112        let d = (a - b).abs();
113        sum_abs += d;
114        if d > max_abs {
115            max_abs = d;
116            argmax = i;
117        }
118        let denom = b.abs().max(1e-12);
119        let rel = d / denom;
120        if rel > max_rel {
121            max_rel = rel;
122        }
123        dot += a * b;
124        na += a * a;
125        nb += b * b;
126    }
127    let n = out.len();
128    let cosine = if na > 0.0 && nb > 0.0 {
129        dot / (na.sqrt() * nb.sqrt())
130    } else {
131        1.0
132    };
133    Ok(Diff {
134        n,
135        max_abs,
136        max_rel,
137        mean_abs: if n > 0 { sum_abs / n as f32 } else { 0.0 },
138        cosine,
139        argmax_diff_index: argmax,
140    })
141}
142
143/// `diff` + tolerance check. Use this in tests / parity harnesses.
144pub fn compare(out: &[f32], reference: &[f32], tol: Tolerance) -> Result<Diff, VerifyError> {
145    let d = diff(out, reference)?;
146    if d.max_abs > tol.max_abs || d.max_rel > tol.max_rel || d.cosine < tol.min_cosine {
147        return Err(VerifyError::ToleranceExceeded {
148            diff: d,
149            tolerance: tol,
150        });
151    }
152    Ok(d)
153}
154
#[cfg(test)]
mod tests {
    use super::*;

    /// A vector compared with itself has no absolute error and cosine ~ 1.
    #[test]
    fn identical_passes() {
        let values = vec![1.0, 2.0, 3.0];
        let report = compare(&values, &values, Tolerance::STRICT).unwrap();
        assert_eq!(report.max_abs, 0.0);
        assert!((report.cosine - 1.0).abs() < 1e-6);
    }

    /// Perturbations around 1e-6 stay inside the strict preset.
    #[test]
    fn tiny_diff_passes_strict() {
        let out = [1.0f32, 2.0, 3.0];
        let reference = [1.0 + 1e-6, 2.0 - 1e-6, 3.0];
        compare(&out, &reference, Tolerance::STRICT).unwrap();
    }

    /// A grossly wrong element is reported as a tolerance violation.
    #[test]
    fn big_diff_fails() {
        let out = [1.0f32, 2.0, 3.0];
        let reference = [1.0, 2.0, 99.0];
        let outcome = compare(&out, &reference, Tolerance::STRICT);
        assert!(matches!(
            outcome,
            Err(VerifyError::ToleranceExceeded { .. })
        ));
    }

    /// NaN in the output short-circuits to `Nonfinite`.
    #[test]
    fn nan_is_hard_fail() {
        let out = [1.0f32, f32::NAN];
        let reference = [1.0, 0.0];
        let outcome = compare(&out, &reference, Tolerance::STRICT);
        assert!(matches!(outcome, Err(VerifyError::Nonfinite)));
    }
}