// Opt the enclosing module out of the `missing_docs` lint.
// NOTE(review): confirm this is still wanted; the public items here could carry their own docs.
#![allow(missing_docs)]
/// Failure reported by a Jidoka numerical sanity check.
///
/// `JidokaGuard::validate` and `JidokaGuard::check_input` produce the
/// `NumericalDeviation`, `NaNDetected` and `InfDetected` variants.
#[derive(Debug, Clone, PartialEq)]
pub enum JidokaError {
    /// A computed value differed from its reference by more than the allowed tolerance.
    NumericalDeviation {
        /// Value that was produced.
        computed: f32,
        /// Reference value it was compared against.
        expected: f32,
        /// Relative error between `computed` and `expected`; `JidokaGuard::validate`
        /// computes it as `|computed - expected| / max(|computed|, |expected|, 1e-10)`.
        relative_error: f32,
    },
    /// A NaN value was encountered.
    NaNDetected {
        /// Caller-supplied label describing where the value came from.
        location: &'static str,
    },
    /// An infinite value was encountered.
    InfDetected {
        /// Caller-supplied label describing where the value came from.
        location: &'static str,
    },
    /// Two 3-component shapes did not match.
    ///
    /// NOTE(review): no code in this module constructs this variant; the meaning of
    /// each tuple component is defined by the callers that raise it.
    DimensionMismatch {
        /// Shape that was required.
        expected: (usize, usize, usize),
        /// Shape that was actually observed.
        actual: (usize, usize, usize),
    },
}
38
39impl std::fmt::Display for JidokaError {
40 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
41 match self {
42 Self::NumericalDeviation { computed, expected, relative_error } => {
43 write!(
44 f,
45 "Jidoka: numerical deviation - computed={}, expected={}, error={}",
46 computed, expected, relative_error
47 )
48 }
49 Self::NaNDetected { location } => {
50 write!(f, "Jidoka: NaN detected at {}", location)
51 }
52 Self::InfDetected { location } => {
53 write!(f, "Jidoka: Inf detected at {}", location)
54 }
55 Self::DimensionMismatch { expected, actual } => {
56 write!(f, "Jidoka: dimension mismatch - expected {:?}, got {:?}", expected, actual)
57 }
58 }
59 }
60}
61
/// Lets `JidokaError` be used wherever a `std::error::Error` is expected
/// (e.g. `Box<dyn std::error::Error>`). All trait methods keep their defaults,
/// so no underlying `source()` is reported; the message comes from `Display`.
impl std::error::Error for JidokaError {}
63
/// Settings for Jidoka numerical sanity checks on `f32` values.
#[derive(Clone, Debug)]
pub struct JidokaGuard {
    /// Maximum accepted relative error when `validate` compares a computed value
    /// with its reference.
    pub epsilon: f32,
    /// When `true`, NaN and infinite values are reported as errors
    /// (in `validate`, before the tolerance comparison).
    pub check_special: bool,
    /// Sampling interval.
    /// NOTE(review): not read by the checks in this module; presumably callers
    /// validate one value out of every `sample_rate` — confirm against callers.
    pub sample_rate: usize,
}
74
75impl Default for JidokaGuard {
76 fn default() -> Self {
77 Self {
78 epsilon: 1e-5,
79 check_special: true,
80 sample_rate: 1000, }
82 }
83}
84
85impl JidokaGuard {
86 pub fn strict() -> Self {
88 Self { epsilon: 1e-6, check_special: true, sample_rate: 1 }
89 }
90
91 #[inline]
93 pub fn validate(&self, computed: f32, expected: f32) -> Result<(), JidokaError> {
94 if self.check_special {
95 if computed.is_nan() {
96 return Err(JidokaError::NaNDetected { location: "output" });
97 }
98 if computed.is_infinite() {
99 return Err(JidokaError::InfDetected { location: "output" });
100 }
101 }
102
103 let abs_diff = (computed - expected).abs();
104 let max_abs = computed.abs().max(expected.abs()).max(1e-10);
105 let relative_error = abs_diff / max_abs;
106
107 if relative_error > self.epsilon {
108 return Err(JidokaError::NumericalDeviation { computed, expected, relative_error });
109 }
110
111 Ok(())
112 }
113
114 #[inline]
116 pub fn check_input(&self, value: f32, location: &'static str) -> Result<(), JidokaError> {
117 if !self.check_special {
118 return Ok(());
119 }
120 if value.is_nan() {
121 return Err(JidokaError::NaNDetected { location });
122 }
123 if value.is_infinite() {
124 return Err(JidokaError::InfDetected { location });
125 }
126 Ok(())
127 }
128}
129
#[cfg(test)]
mod tests {
    use super::*;

    // Constructors: check every configured field.

    #[test]
    fn test_jidoka_default() {
        let JidokaGuard { epsilon, check_special, sample_rate } = JidokaGuard::default();
        assert!((epsilon - 1e-5).abs() < 1e-10);
        assert!(check_special);
        assert_eq!(sample_rate, 1000);
    }

    #[test]
    fn test_jidoka_strict() {
        let JidokaGuard { epsilon, check_special, sample_rate } = JidokaGuard::strict();
        assert!((epsilon - 1e-6).abs() < 1e-10);
        assert!(check_special);
        assert_eq!(sample_rate, 1);
    }

    // `validate`: pass, special values, and tolerance violations.

    #[test]
    fn test_validate_pass() {
        let guard = JidokaGuard::default();
        for &(computed, expected) in &[(1.0_f32, 1.0_f32), (1.0, 1.000001)] {
            assert!(guard.validate(computed, expected).is_ok());
        }
    }

    #[test]
    fn test_validate_nan() {
        let outcome = JidokaGuard::default().validate(f32::NAN, 1.0);
        assert!(matches!(outcome, Err(JidokaError::NaNDetected { .. })));
    }

    #[test]
    fn test_validate_inf() {
        let outcome = JidokaGuard::default().validate(f32::INFINITY, 1.0);
        assert!(matches!(outcome, Err(JidokaError::InfDetected { .. })));
    }

    #[test]
    fn test_validate_deviation() {
        let outcome = JidokaGuard::strict().validate(1.0, 2.0);
        assert!(matches!(outcome, Err(JidokaError::NumericalDeviation { .. })));
    }

    // `check_input`: special values are flagged.

    #[test]
    fn test_check_input_nan() {
        let outcome = JidokaGuard::default().check_input(f32::NAN, "test");
        assert!(matches!(outcome, Err(JidokaError::NaNDetected { .. })));
    }

    #[test]
    fn test_check_input_inf() {
        let outcome = JidokaGuard::default().check_input(f32::INFINITY, "test");
        assert!(matches!(outcome, Err(JidokaError::InfDetected { .. })));
    }

    // `Display`: each variant's message mentions its kind.

    #[test]
    fn test_error_display() {
        let cases = [
            (JidokaError::NaNDetected { location: "test" }, "NaN"),
            (JidokaError::InfDetected { location: "test" }, "Inf"),
            (
                JidokaError::NumericalDeviation { computed: 1.0, expected: 2.0, relative_error: 0.5 },
                "deviation",
            ),
            (JidokaError::DimensionMismatch { expected: (1, 2, 3), actual: (4, 5, 6) }, "mismatch"),
        ];
        for (err, needle) in cases.iter() {
            assert!(err.to_string().contains(*needle));
        }
    }
}