rlx_runtime/nan_check.rs
1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! NaN/inf check epilogue (plan #18).
17//!
18//! Borrowed from MAX's `Mogg/MOGGKernelAPI/nan_check.mojo` pattern.
19//! When the `nan-check` Cargo feature is on, [`scan`] reports the
20//! first NaN or inf in a slice — useful as a debug epilogue on
21//! every output buffer to localize precision blow-ups to the op
22//! that introduced them.
23//!
24//! Always present in the API surface so callers can compile against
25//! it; the feature flag controls whether it's a real scan or a
26//! no-op (returns `Ok(())` immediately).
27
/// Kind of non-finite value that made a buffer fail the check.
///
/// Infinities are split by sign so a report can say which direction the
/// value blew up in.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BadValue {
    /// A NaN value.
    Nan,
    /// Positive infinity (`+inf`).
    PosInf,
    /// Negative infinity (`-inf`).
    NegInf,
}
35
36#[derive(Debug)]
37pub struct NanCheckError {
38 pub kind: BadValue,
39 pub index: usize,
40 pub label: String,
41}
42
43impl std::fmt::Display for NanCheckError {
44 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
45 let what = match self.kind {
46 BadValue::Nan => "NaN",
47 BadValue::PosInf => "+inf",
48 BadValue::NegInf => "-inf",
49 };
50 write!(f, "{} at index {} of `{}`", what, self.index, self.label)
51 }
52}
53
54impl std::error::Error for NanCheckError {}
55
/// Scan `data` for the first NaN or infinity.
///
/// This is the real check, compiled only when the `nan-check` feature is
/// ON. It walks the slice once (O(n)) and stops at the first non-finite
/// element. With the feature OFF, a no-op `scan` with the same signature
/// is compiled instead, so callers never need their own `cfg`.
///
/// # Errors
///
/// Returns a [`NanCheckError`] naming the first NaN, `+inf`, or `-inf` in
/// `data`, its index, and a copy of `label`.
#[cfg(feature = "nan-check")]
// Plain `#[inline]` rather than `#[inline(always)]`: this variant contains a
// loop, and forcing it into every epilogue call site only bloats code.
#[inline]
pub fn scan(label: &str, data: &[f32]) -> Result<(), NanCheckError> {
    // `!is_finite()` holds exactly for NaN, +inf and -inf, so a single
    // predicate locates the first offending element.
    match data.iter().enumerate().find(|(_, v)| !v.is_finite()) {
        None => Ok(()),
        Some((index, &v)) => {
            let kind = if v.is_nan() {
                BadValue::Nan
            } else if v.is_sign_positive() {
                BadValue::PosInf
            } else {
                BadValue::NegInf
            };
            // The label is only copied on the (rare) failure path.
            Err(NanCheckError {
                kind,
                index,
                label: label.to_string(),
            })
        }
    }
}
86
87#[cfg(not(feature = "nan-check"))]
88#[inline(always)]
89pub fn scan(_label: &str, _data: &[f32]) -> Result<(), NanCheckError> {
90 Ok(())
91}
92
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn clean_data_passes() {
        let data = [1.0, 2.0, -3.5, 0.0];
        assert!(scan("clean", &data).is_ok());
    }

    #[test]
    fn empty_slice_passes() {
        assert!(scan("empty", &[]).is_ok());
    }

    #[test]
    fn display_names_value_index_and_label() {
        let err = NanCheckError {
            kind: BadValue::NegInf,
            index: 7,
            label: "logits".to_string(),
        };
        assert_eq!(err.to_string(), "-inf at index 7 of `logits`");
    }

    // With the feature off, `scan` must be a pure no-op even on bad data.
    #[cfg(not(feature = "nan-check"))]
    #[test]
    fn disabled_scan_is_noop() {
        let data = [f32::NAN, f32::INFINITY, f32::NEG_INFINITY];
        assert!(scan("disabled", &data).is_ok());
    }

    #[cfg(feature = "nan-check")]
    #[test]
    fn detects_nan() {
        let data = [1.0, f32::NAN, 3.0];
        let err = scan("nan", &data).unwrap_err();
        assert!(matches!(err.kind, BadValue::Nan));
        assert_eq!(err.index, 1);
    }

    #[cfg(feature = "nan-check")]
    #[test]
    fn detects_pos_inf() {
        let data = [f32::INFINITY, 0.0];
        let err = scan("inf", &data).unwrap_err();
        assert!(matches!(err.kind, BadValue::PosInf));
        assert_eq!(err.index, 0);
    }

    #[cfg(feature = "nan-check")]
    #[test]
    fn detects_neg_inf() {
        let data = [0.0, 1.0, f32::NEG_INFINITY];
        let err = scan("neginf", &data).unwrap_err();
        assert!(matches!(err.kind, BadValue::NegInf));
        assert_eq!(err.index, 2);
        assert_eq!(err.label, "neginf");
    }

    #[cfg(feature = "nan-check")]
    #[test]
    fn reports_first_bad_value() {
        let data = [1.0, f32::NEG_INFINITY, f32::NAN, f32::INFINITY];
        let err = scan("mixed", &data).unwrap_err();
        assert!(matches!(err.kind, BadValue::NegInf));
        assert_eq!(err.index, 1);
    }
}