Skip to main content

rlx_runtime/
nan_check.rs

// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

//! NaN/inf check epilogue (plan #18).
//!
//! Borrowed from MAX's `Mogg/MOGGKernelAPI/nan_check.mojo` pattern.
//! When the `nan-check` Cargo feature is on, [`scan`] reports the
//! first NaN or infinity in a slice — useful as a debug epilogue on
//! every output buffer to localize precision blow-ups to the op
//! that introduced them.
//!
//! [`scan`] is always present in the API surface so callers can
//! compile against it; the feature flag only selects whether it is a
//! real scan or a no-op that returns `Ok(())` immediately.

/// What was found in a buffer that fails the check.
///
/// Derives `PartialEq`/`Eq`/`Hash` so callers can compare kinds with
/// `==` / `assert_eq!` or use them as map keys instead of `matches!`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BadValue {
    /// Not-a-number.
    Nan,
    /// Positive infinity.
    PosInf,
    /// Negative infinity.
    NegInf,
}

36#[derive(Debug)]
37pub struct NanCheckError {
38    pub kind: BadValue,
39    pub index: usize,
40    pub label: String,
41}
42
43impl std::fmt::Display for NanCheckError {
44    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
45        let what = match self.kind {
46            BadValue::Nan => "NaN",
47            BadValue::PosInf => "+inf",
48            BadValue::NegInf => "-inf",
49        };
50        write!(f, "{} at index {} of `{}`", what, self.index, self.label)
51    }
52}
53
54impl std::error::Error for NanCheckError {}
55
/// Scan `data` for the first NaN or infinity.
///
/// With the `nan-check` feature OFF, a no-op variant is compiled instead
/// that returns `Ok(())` without reading `data`. With it ON (this variant),
/// the slice is walked up to the first non-finite element — O(n), but only
/// paid when a caller opts in.
///
/// # Errors
///
/// Returns [`NanCheckError`] for the first non-finite element. The error
/// carries its kind ([`BadValue::Nan`], [`BadValue::PosInf`] or
/// [`BadValue::NegInf`]), its index in `data`, and an owned copy of `label`.
#[cfg(feature = "nan-check")]
#[inline]
pub fn scan(label: &str, data: &[f32]) -> Result<(), NanCheckError> {
    // Error construction (classification + label allocation) runs at most
    // once per call, so keep it out of line and marked cold; the scanning
    // loop below stays a single finiteness test per element.
    #[cold]
    #[inline(never)]
    fn report(label: &str, index: usize, value: f32) -> NanCheckError {
        // `value` is non-finite here: either NaN or an infinity of some sign.
        let kind = if value.is_nan() {
            BadValue::Nan
        } else if value > 0.0 {
            BadValue::PosInf
        } else {
            BadValue::NegInf
        };
        NanCheckError {
            kind,
            index,
            label: label.to_string(),
        }
    }

    // `!is_finite()` is true exactly for NaN, +inf and -inf, so this finds
    // the same first offending index as separate NaN/inf checks would.
    match data
        .iter()
        .copied()
        .enumerate()
        .find(|&(_, v)| !v.is_finite())
    {
        Some((index, value)) => Err(report(label, index, value)),
        None => Ok(()),
    }
}

87#[cfg(not(feature = "nan-check"))]
88#[inline(always)]
89pub fn scan(_label: &str, _data: &[f32]) -> Result<(), NanCheckError> {
90    Ok(())
91}
92
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn clean_data_passes() {
        let data: [f32; 4] = [1.0, 2.0, -3.5, 0.0];
        assert!(scan("clean", &data).is_ok());
    }

    #[test]
    fn empty_slice_passes() {
        let data: [f32; 0] = [];
        assert!(scan("empty", &data).is_ok());
    }

    // Display formatting is independent of the `nan-check` feature.
    #[test]
    fn display_names_value_index_and_label() {
        let err = NanCheckError {
            kind: BadValue::NegInf,
            index: 3,
            label: "logits".to_string(),
        };
        assert_eq!(err.to_string(), "-inf at index 3 of `logits`");
    }

    // With the feature off, `scan` must accept even non-finite data.
    #[cfg(not(feature = "nan-check"))]
    #[test]
    fn disabled_scan_is_a_no_op() {
        let data: [f32; 3] = [f32::NAN, f32::INFINITY, f32::NEG_INFINITY];
        assert!(scan("disabled", &data).is_ok());
    }

    #[cfg(feature = "nan-check")]
    #[test]
    fn detects_nan() {
        let data: [f32; 3] = [1.0, f32::NAN, 3.0];
        let err = scan("nan", &data).unwrap_err();
        assert!(matches!(err.kind, BadValue::Nan));
        assert_eq!(err.index, 1);
        assert_eq!(err.label, "nan");
    }

    #[cfg(feature = "nan-check")]
    #[test]
    fn detects_pos_inf() {
        let data: [f32; 2] = [f32::INFINITY, 0.0];
        let err = scan("inf", &data).unwrap_err();
        assert!(matches!(err.kind, BadValue::PosInf));
        assert_eq!(err.index, 0);
    }

    #[cfg(feature = "nan-check")]
    #[test]
    fn detects_neg_inf() {
        let data: [f32; 3] = [0.0, 1.0, f32::NEG_INFINITY];
        let err = scan("neg", &data).unwrap_err();
        assert!(matches!(err.kind, BadValue::NegInf));
        assert_eq!(err.index, 2);
    }

    // Only the earliest non-finite element is reported.
    #[cfg(feature = "nan-check")]
    #[test]
    fn reports_first_bad_value() {
        let data: [f32; 4] = [0.5, f32::NEG_INFINITY, f32::NAN, f32::INFINITY];
        let err = scan("first", &data).unwrap_err();
        assert!(matches!(err.kind, BadValue::NegInf));
        assert_eq!(err.index, 1);
    }
}