/// A quality condition a `JidokaGuard` can be configured to watch for.
///
/// Variant order is significant for `Debug` output only; no numeric discriminants are used.
#[derive(Debug, Clone, PartialEq)]
pub enum JidokaCondition {
    /// Fires when any element of a single output buffer is NaN.
    NanDetected,
    /// Fires when any element of a single output buffer is `+inf` or `-inf`.
    InfDetected,
    /// Fires when two outputs differ by more than `tolerance` (largest absolute
    /// per-element difference).
    BackendDivergence { tolerance: f32 },
    /// Fires when a regression exceeds `threshold_pct` percent.
    ///
    /// NOTE(review): no `JidokaGuard::check_*` method in this file evaluates this
    /// condition; it is presumably checked by callers.
    PerformanceRegression { threshold_pct: f32 },
    /// Fires when two runs do not produce bit-for-bit identical output.
    DeterminismFailure,
}
26
/// What the caller is expected to do once a guard's condition fires.
///
/// The guard's `check_*` methods never read this value: they always return `Err` on a
/// violation, and interpreting the action is left to the caller.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum JidokaAction {
    /// Halt processing; used by the built-in `nan_guard`, `inf_guard` and `divergence_guard`.
    Stop,
    /// Intended to record the violation and keep going (caller-defined).
    LogAndContinue,
    /// Intended to produce a visual report of the violation (caller-defined).
    VisualReport,
}
37
/// Error raised when a Jidoka guard detects a violated quality condition.
///
/// Every variant carries the `context` label of the guard that raised it, so the
/// failing operation can be identified from the error alone. `PartialEq` is derived
/// (like `JidokaCondition`) so errors can be compared directly, e.g. with `assert_eq!`.
#[derive(Debug, Clone, PartialEq)]
pub enum JidokaError {
    /// One or more output elements were NaN.
    NanDetected {
        /// Label of the guard / operation that produced the output.
        context: String,
        /// Indices of every NaN element, in ascending order.
        indices: Vec<usize>,
    },
    /// One or more output elements were `+inf` or `-inf`.
    InfDetected {
        /// Label of the guard / operation that produced the output.
        context: String,
        /// Indices of every infinite element, in ascending order.
        indices: Vec<usize>,
    },
    /// Two outputs (e.g. from different backends) differed by more than the tolerance.
    BackendDivergence {
        /// Label of the guard / operation being compared.
        context: String,
        /// Largest absolute per-element difference observed.
        max_diff: f32,
        /// Tolerance the difference was checked against.
        tolerance: f32,
    },
    /// A measured regression exceeded its threshold.
    PerformanceRegression {
        /// Label of the guarded operation.
        context: String,
        /// Measured regression, in percent (rendered as `NN.NN%`).
        regression_pct: f32,
        /// Maximum allowed regression, in percent.
        threshold_pct: f32,
    },
    /// Two runs of the same computation were not bitwise identical.
    DeterminismFailure {
        /// Label of the guarded operation.
        context: String,
        /// Index of the first element at which the runs differ.
        first_diff_index: usize,
    },
}
81
82impl std::fmt::Display for JidokaError {
83 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
84 match self {
85 Self::NanDetected { context, indices } => {
86 write!(f, "Jidoka: NaN detected at {context} (indices: {indices:?})")
87 }
88 Self::InfDetected { context, indices } => {
89 write!(f, "Jidoka: Infinity detected at {context} (indices: {indices:?})")
90 }
91 Self::BackendDivergence { context, max_diff, tolerance } => {
92 write!(
93 f,
94 "Jidoka: Backend divergence at {context} (max_diff: {max_diff}, tolerance: {tolerance})"
95 )
96 }
97 Self::PerformanceRegression { context, regression_pct, threshold_pct } => {
98 write!(
99 f,
100 "Jidoka: Performance regression at {context} ({regression_pct:.2}% > {threshold_pct:.2}%)"
101 )
102 }
103 Self::DeterminismFailure { context, first_diff_index } => {
104 write!(
105 f,
106 "Jidoka: Determinism failure at {context} (first diff at index {first_diff_index})"
107 )
108 }
109 }
110 }
111}
112
/// Marks [`JidokaError`] as a standard error so it can be used as `Box<dyn std::error::Error>`.
///
/// The message comes from the `Display` impl; `source()` keeps the default (`None`) because
/// no variant wraps an underlying error.
impl std::error::Error for JidokaError {}
114
115#[derive(Debug, Clone)]
120pub struct JidokaGuard {
121 pub condition: JidokaCondition,
123 pub action: JidokaAction,
125 pub context: String,
127}
128
129impl JidokaGuard {
130 #[must_use]
132 pub fn new(
133 condition: JidokaCondition,
134 action: JidokaAction,
135 context: impl Into<String>,
136 ) -> Self {
137 Self { condition, action, context: context.into() }
138 }
139
140 #[must_use]
142 pub fn nan_guard(context: impl Into<String>) -> Self {
143 Self::new(JidokaCondition::NanDetected, JidokaAction::Stop, context)
144 }
145
146 #[must_use]
148 pub fn inf_guard(context: impl Into<String>) -> Self {
149 Self::new(JidokaCondition::InfDetected, JidokaAction::Stop, context)
150 }
151
152 #[must_use]
154 pub fn divergence_guard(tolerance: f32, context: impl Into<String>) -> Self {
155 Self::new(JidokaCondition::BackendDivergence { tolerance }, JidokaAction::Stop, context)
156 }
157
158 pub fn check_output(&self, output: &[f32]) -> Result<(), JidokaError> {
164 match &self.condition {
165 JidokaCondition::NanDetected => {
166 let nan_indices: Vec<usize> =
167 output.iter().enumerate().filter(|(_, x)| x.is_nan()).map(|(i, _)| i).collect();
168
169 if !nan_indices.is_empty() {
170 return Err(JidokaError::NanDetected {
171 context: self.context.clone(),
172 indices: nan_indices,
173 });
174 }
175 }
176 JidokaCondition::InfDetected => {
177 let inf_indices: Vec<usize> = output
178 .iter()
179 .enumerate()
180 .filter(|(_, x)| x.is_infinite())
181 .map(|(i, _)| i)
182 .collect();
183
184 if !inf_indices.is_empty() {
185 return Err(JidokaError::InfDetected {
186 context: self.context.clone(),
187 indices: inf_indices,
188 });
189 }
190 }
191 JidokaCondition::BackendDivergence { .. }
192 | JidokaCondition::PerformanceRegression { .. }
193 | JidokaCondition::DeterminismFailure => {} }
195 Ok(())
196 }
197
198 pub fn check_divergence(&self, a: &[f32], b: &[f32]) -> Result<(), JidokaError> {
204 if let JidokaCondition::BackendDivergence { tolerance } = &self.condition {
205 let max_diff =
206 a.iter().zip(b.iter()).map(|(x, y)| (x - y).abs()).fold(0.0_f32, f32::max);
207
208 if max_diff > *tolerance {
209 return Err(JidokaError::BackendDivergence {
210 context: self.context.clone(),
211 max_diff,
212 tolerance: *tolerance,
213 });
214 }
215 }
216 Ok(())
217 }
218
219 pub fn check_determinism(&self, a: &[f32], b: &[f32]) -> Result<(), JidokaError> {
225 if let JidokaCondition::DeterminismFailure = &self.condition {
226 for (i, (x, y)) in a.iter().zip(b.iter()).enumerate() {
227 if x.to_bits() != y.to_bits() {
229 return Err(JidokaError::DeterminismFailure {
230 context: self.context.clone(),
231 first_diff_index: i,
232 });
233 }
234 }
235 }
236 Ok(())
237 }
238}
239
#[cfg(test)]
mod tests {
    use super::*;

    // ---- check_output -------------------------------------------------------------------

    #[test]
    fn test_jidoka_nan_detection() {
        let nan_guard = JidokaGuard::nan_guard("test_operation");
        let samples = [1.0, 2.0, f32::NAN, 4.0];

        let outcome = nan_guard.check_output(&samples);
        assert!(outcome.is_err());

        match outcome {
            Err(JidokaError::NanDetected { indices, .. }) => assert_eq!(indices, vec![2]),
            _ => panic!("Expected NanDetected error"),
        }
    }

    #[test]
    fn test_jidoka_nan_no_false_positive() {
        let nan_guard = JidokaGuard::nan_guard("test_operation");
        let samples = [1.0, 2.0, 3.0, 4.0];

        assert!(nan_guard.check_output(&samples).is_ok());
    }

    #[test]
    fn test_jidoka_inf_detection() {
        let inf_guard = JidokaGuard::inf_guard("test_operation");
        // Both signs of infinity must be reported.
        let samples = [1.0, f32::INFINITY, 3.0, f32::NEG_INFINITY];

        let outcome = inf_guard.check_output(&samples);
        assert!(outcome.is_err());

        match outcome {
            Err(JidokaError::InfDetected { indices, .. }) => assert_eq!(indices, vec![1, 3]),
            _ => panic!("Expected InfDetected error"),
        }
    }

    // ---- check_divergence ---------------------------------------------------------------

    #[test]
    fn test_jidoka_divergence_detection() {
        let divergence_guard = JidokaGuard::divergence_guard(1e-5, "cross_backend");
        let reference = [1.0, 2.0, 3.0, 4.0];
        let candidate = [1.0, 2.0, 3.1, 4.0];

        let outcome = divergence_guard.check_divergence(&reference, &candidate);
        assert!(outcome.is_err());

        match outcome {
            Err(JidokaError::BackendDivergence { max_diff, .. }) => {
                assert!((max_diff - 0.1).abs() < 1e-6);
            }
            _ => panic!("Expected BackendDivergence error"),
        }
    }

    #[test]
    fn test_jidoka_divergence_within_tolerance() {
        let divergence_guard = JidokaGuard::divergence_guard(1e-5, "cross_backend");
        let reference = [1.0, 2.0, 3.0, 4.0];
        let candidate = [1.0, 2.0, 3.0 + 1e-7, 4.0];

        assert!(divergence_guard.check_divergence(&reference, &candidate).is_ok());
    }

    // ---- check_determinism --------------------------------------------------------------

    #[test]
    fn test_jidoka_determinism_check() {
        let determinism_guard = JidokaGuard::new(
            JidokaCondition::DeterminismFailure,
            JidokaAction::Stop,
            "determinism_test",
        );
        let first_run = [1.0, 2.0, 3.0, 4.0];
        let second_run = [1.0, 2.0, 3.0, 4.0];

        assert!(determinism_guard.check_determinism(&first_run, &second_run).is_ok());
    }

    #[test]
    fn test_jidoka_determinism_failure() {
        let determinism_guard = JidokaGuard::new(
            JidokaCondition::DeterminismFailure,
            JidokaAction::Stop,
            "determinism_test",
        );
        let a: [f32; 4] = [1.0, 2.0, 3.0, 4.0];
        let b: [f32; 4] = [1.0, 2.0, 3.000_001, 4.0];
        // Guard against the literal rounding to the same f32, which would make the test vacuous.
        assert_ne!(a[2].to_bits(), b[2].to_bits(), "Test values must differ");

        let outcome = determinism_guard.check_determinism(&a, &b);
        assert!(outcome.is_err());

        match outcome {
            Err(JidokaError::DeterminismFailure { first_diff_index, .. }) => {
                assert_eq!(first_diff_index, 2);
            }
            _ => panic!("Expected DeterminismFailure error"),
        }
    }

    // ---- Display / Error ----------------------------------------------------------------

    #[test]
    fn test_jidoka_error_display() {
        let nan_err = JidokaError::NanDetected { context: "test".to_string(), indices: vec![0, 2] };
        let nan_text = nan_err.to_string();
        assert!(nan_text.contains("NaN"));
        assert!(nan_text.contains("test"));

        let divergence_err = JidokaError::BackendDivergence {
            context: "cross".to_string(),
            max_diff: 0.01,
            tolerance: 0.001,
        };
        assert!(divergence_err.to_string().contains("divergence"));
    }

    #[test]
    fn test_jidoka_error_display_inf_detected() {
        let err =
            JidokaError::InfDetected { context: "matmul_output".to_string(), indices: vec![1, 3] };
        let display = err.to_string();

        assert!(display.contains("Infinity"), "Display should contain 'Infinity', got: {display}");
        assert!(
            display.contains("matmul_output"),
            "Display should contain context, got: {display}"
        );
        assert!(display.contains("[1, 3]"), "Display should contain indices, got: {display}");
    }

    #[test]
    fn test_jidoka_error_display_performance_regression() {
        let err = JidokaError::PerformanceRegression {
            context: "avx2_dot_product".to_string(),
            regression_pct: 15.75,
            threshold_pct: 5.0,
        };
        let display = err.to_string();

        assert!(
            display.contains("Performance regression"),
            "Display should contain 'Performance regression', got: {display}"
        );
        assert!(
            display.contains("avx2_dot_product"),
            "Display should contain context, got: {display}"
        );
        assert!(display.contains("15.75"), "Display should contain regression_pct, got: {display}");
        assert!(display.contains("5.00"), "Display should contain threshold_pct, got: {display}");
    }

    #[test]
    fn test_jidoka_error_display_determinism_failure() {
        let err = JidokaError::DeterminismFailure {
            context: "sse2_vs_avx2".to_string(),
            first_diff_index: 42,
        };
        let display = err.to_string();

        assert!(
            display.contains("Determinism failure"),
            "Display should contain 'Determinism failure', got: {display}"
        );
        assert!(display.contains("sse2_vs_avx2"), "Display should contain context, got: {display}");
        assert!(display.contains("42"), "Display should contain first_diff_index, got: {display}");
    }

    #[test]
    fn test_jidoka_error_is_std_error() {
        // Every variant must be usable through the `dyn Error` trait object.
        let errors: Vec<Box<dyn std::error::Error>> = vec![
            Box::new(JidokaError::NanDetected { context: "a".to_string(), indices: vec![] }),
            Box::new(JidokaError::InfDetected { context: "b".to_string(), indices: vec![] }),
            Box::new(JidokaError::BackendDivergence {
                context: "c".to_string(),
                max_diff: 0.0,
                tolerance: 0.0,
            }),
            Box::new(JidokaError::PerformanceRegression {
                context: "d".to_string(),
                regression_pct: 0.0,
                threshold_pct: 0.0,
            }),
            Box::new(JidokaError::DeterminismFailure {
                context: "e".to_string(),
                first_diff_index: 0,
            }),
        ];

        for boxed in errors.iter() {
            assert!(
                !boxed.to_string().is_empty(),
                "Error::to_string() should produce non-empty output"
            );
        }
    }

    // ---- Edge cases and derives ---------------------------------------------------------

    #[test]
    fn test_empty_output_checks() {
        let nan_guard = JidokaGuard::nan_guard("empty_test");
        assert!(nan_guard.check_output(&[]).is_ok());
    }

    #[test]
    fn test_single_element_checks() {
        let nan_guard = JidokaGuard::nan_guard("single_test");

        assert!(nan_guard.check_output(&[1.0]).is_ok());
        assert!(nan_guard.check_output(&[f32::NAN]).is_err());
    }

    #[test]
    fn test_jidoka_condition_clone() {
        let original = JidokaCondition::BackendDivergence { tolerance: 1e-5 };
        let copy = original.clone();
        assert_eq!(original, copy);
    }

    #[test]
    fn test_jidoka_action_eq() {
        assert_eq!(JidokaAction::Stop, JidokaAction::Stop);
        assert_ne!(JidokaAction::Stop, JidokaAction::LogAndContinue);
    }
}