Skip to main content

trueno/simulation/
jidoka.rs

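//! Jidoka Guard (Built-in Quality: Stop on Defect)
//!
//! Implements the Toyota Production System's Jidoka principle:
//! stop production as soon as a defect is detected.
//!
//! A [`JidokaGuard`] watches `f32` outputs for a single [`JidokaCondition`]
//! (for example NaN values, infinities, cross-backend divergence, or
//! bitwise non-determinism) and reports a [`JidokaError`] when it triggers.
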
/// The kind of defect a [`JidokaGuard`] watches for.
///
/// Variants that carry data hold the threshold the corresponding check
/// compares against.
#[derive(Clone, Debug, PartialEq)]
pub enum JidokaCondition {
    /// The output contains at least one NaN.
    NanDetected,
    /// The output contains at least one positive or negative infinity.
    InfDetected,
    /// Outputs from two backends differ by more than `tolerance`.
    BackendDivergence {
        /// Largest accepted absolute element-wise difference.
        tolerance: f32,
    },
    /// Performance regressed by more than `threshold_pct` percent.
    PerformanceRegression {
        /// Largest accepted regression, in percent.
        threshold_pct: f32,
    },
    /// Re-running with the same seed produced a different output.
    DeterminismFailure,
}

/// What should happen once a guard's condition fires.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum JidokaAction {
    /// Halt immediately and report the defect.
    Stop,
    /// Record the defect and keep going ("soft" Jidoka).
    LogAndContinue,
    /// Produce a visual diff report of the defect.
    VisualReport,
}

/// Defect reported by a [`JidokaGuard`] check.
///
/// Every variant carries the guard's `context` string plus the measurements
/// that explain why the check failed.
#[derive(Clone, Debug)]
pub enum JidokaError {
    /// One or more NaN values were found.
    NanDetected {
        /// Label of the guarded operation.
        context: String,
        /// Positions of the NaN values in the checked output.
        indices: Vec<usize>,
    },
    /// One or more infinite values were found.
    InfDetected {
        /// Label of the guarded operation.
        context: String,
        /// Positions of the infinite values in the checked output.
        indices: Vec<usize>,
    },
    /// Two backends produced outputs that differ beyond the tolerance.
    BackendDivergence {
        /// Label of the guarded operation.
        context: String,
        /// Largest difference observed between the outputs.
        max_diff: f32,
        /// Tolerance the difference was compared against.
        tolerance: f32,
    },
    /// Performance regressed beyond the allowed threshold.
    PerformanceRegression {
        /// Label of the guarded operation.
        context: String,
        /// Measured regression, in percent.
        regression_pct: f32,
        /// Allowed regression, in percent.
        threshold_pct: f32,
    },
    /// Two runs that should be identical produced different outputs.
    DeterminismFailure {
        /// Label of the guarded operation.
        context: String,
        /// Position of the first element at which the runs differ.
        first_diff_index: usize,
    },
}

82impl std::fmt::Display for JidokaError {
83    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
84        match self {
85            Self::NanDetected { context, indices } => {
86                write!(f, "Jidoka: NaN detected at {context} (indices: {indices:?})")
87            }
88            Self::InfDetected { context, indices } => {
89                write!(f, "Jidoka: Infinity detected at {context} (indices: {indices:?})")
90            }
91            Self::BackendDivergence { context, max_diff, tolerance } => {
92                write!(
93                    f,
94                    "Jidoka: Backend divergence at {context} (max_diff: {max_diff}, tolerance: {tolerance})"
95                )
96            }
97            Self::PerformanceRegression { context, regression_pct, threshold_pct } => {
98                write!(
99                    f,
100                    "Jidoka: Performance regression at {context} ({regression_pct:.2}% > {threshold_pct:.2}%)"
101                )
102            }
103            Self::DeterminismFailure { context, first_diff_index } => {
104                write!(
105                    f,
106                    "Jidoka: Determinism failure at {context} (first diff at index {first_diff_index})"
107                )
108            }
109        }
110    }
111}
112
113impl std::error::Error for JidokaError {}
114
115/// Jidoka guard for simulation tests
116///
117/// Implements Toyota Production System's Jidoka principle:
118/// stop production when a defect is detected.
119#[derive(Debug, Clone)]
120pub struct JidokaGuard {
121    /// Condition that triggers stop
122    pub condition: JidokaCondition,
123    /// Action to take on trigger
124    pub action: JidokaAction,
125    /// Context for debugging
126    pub context: String,
127}
128
129impl JidokaGuard {
130    /// Create a new Jidoka guard
131    #[must_use]
132    pub fn new(
133        condition: JidokaCondition,
134        action: JidokaAction,
135        context: impl Into<String>,
136    ) -> Self {
137        Self { condition, action, context: context.into() }
138    }
139
140    /// Create a NaN detection guard
141    #[must_use]
142    pub fn nan_guard(context: impl Into<String>) -> Self {
143        Self::new(JidokaCondition::NanDetected, JidokaAction::Stop, context)
144    }
145
146    /// Create an infinity detection guard
147    #[must_use]
148    pub fn inf_guard(context: impl Into<String>) -> Self {
149        Self::new(JidokaCondition::InfDetected, JidokaAction::Stop, context)
150    }
151
152    /// Create a backend divergence guard
153    #[must_use]
154    pub fn divergence_guard(tolerance: f32, context: impl Into<String>) -> Self {
155        Self::new(JidokaCondition::BackendDivergence { tolerance }, JidokaAction::Stop, context)
156    }
157
158    /// Check output for NaN/Inf and return error if found
159    ///
160    /// # Errors
161    ///
162    /// Returns `JidokaError` if the condition is triggered
163    pub fn check_output(&self, output: &[f32]) -> Result<(), JidokaError> {
164        match &self.condition {
165            JidokaCondition::NanDetected => {
166                let nan_indices: Vec<usize> =
167                    output.iter().enumerate().filter(|(_, x)| x.is_nan()).map(|(i, _)| i).collect();
168
169                if !nan_indices.is_empty() {
170                    return Err(JidokaError::NanDetected {
171                        context: self.context.clone(),
172                        indices: nan_indices,
173                    });
174                }
175            }
176            JidokaCondition::InfDetected => {
177                let inf_indices: Vec<usize> = output
178                    .iter()
179                    .enumerate()
180                    .filter(|(_, x)| x.is_infinite())
181                    .map(|(i, _)| i)
182                    .collect();
183
184                if !inf_indices.is_empty() {
185                    return Err(JidokaError::InfDetected {
186                        context: self.context.clone(),
187                        indices: inf_indices,
188                    });
189                }
190            }
191            JidokaCondition::BackendDivergence { .. }
192            | JidokaCondition::PerformanceRegression { .. }
193            | JidokaCondition::DeterminismFailure => {} // Handled by compare methods
194        }
195        Ok(())
196    }
197
198    /// Compare two outputs for backend divergence
199    ///
200    /// # Errors
201    ///
202    /// Returns `JidokaError` if divergence exceeds tolerance
203    pub fn check_divergence(&self, a: &[f32], b: &[f32]) -> Result<(), JidokaError> {
204        if let JidokaCondition::BackendDivergence { tolerance } = &self.condition {
205            let max_diff =
206                a.iter().zip(b.iter()).map(|(x, y)| (x - y).abs()).fold(0.0_f32, f32::max);
207
208            if max_diff > *tolerance {
209                return Err(JidokaError::BackendDivergence {
210                    context: self.context.clone(),
211                    max_diff,
212                    tolerance: *tolerance,
213                });
214            }
215        }
216        Ok(())
217    }
218
219    /// Check for determinism (same inputs should produce same outputs)
220    ///
221    /// # Errors
222    ///
223    /// Returns `JidokaError` if outputs differ
224    pub fn check_determinism(&self, a: &[f32], b: &[f32]) -> Result<(), JidokaError> {
225        if let JidokaCondition::DeterminismFailure = &self.condition {
226            for (i, (x, y)) in a.iter().zip(b.iter()).enumerate() {
227                // Use bitwise comparison for exact equality
228                if x.to_bits() != y.to_bits() {
229                    return Err(JidokaError::DeterminismFailure {
230                        context: self.context.clone(),
231                        first_diff_index: i,
232                    });
233                }
234            }
235        }
236        Ok(())
237    }
238}
239
#[cfg(test)]
mod tests {
    use super::*;

    /// Guard configured for bitwise determinism checks, shared by the determinism tests.
    fn determinism_guard() -> JidokaGuard {
        JidokaGuard::new(
            JidokaCondition::DeterminismFailure,
            JidokaAction::Stop,
            "determinism_test",
        )
    }

    // ================================================================
    // NaN / Inf detection
    // ================================================================

    #[test]
    fn test_jidoka_nan_detection() {
        // Falsifiable claim B-027
        let guard = JidokaGuard::nan_guard("test_operation");
        let output: [f32; 4] = [1.0, 2.0, f32::NAN, 4.0];

        match guard.check_output(&output) {
            Err(JidokaError::NanDetected { indices, .. }) => assert_eq!(indices, [2]),
            _ => panic!("Expected NanDetected error"),
        }
    }

    #[test]
    fn test_jidoka_nan_no_false_positive() {
        let guard = JidokaGuard::nan_guard("test_operation");
        let output: [f32; 4] = [1.0, 2.0, 3.0, 4.0];

        assert!(guard.check_output(&output).is_ok());
    }

    #[test]
    fn test_jidoka_inf_detection() {
        // Falsifiable claim B-028
        let guard = JidokaGuard::inf_guard("test_operation");
        let output: [f32; 4] = [1.0, f32::INFINITY, 3.0, f32::NEG_INFINITY];

        match guard.check_output(&output) {
            Err(JidokaError::InfDetected { indices, .. }) => assert_eq!(indices, [1, 3]),
            _ => panic!("Expected InfDetected error"),
        }
    }

    // ================================================================
    // Cross-backend divergence
    // ================================================================

    #[test]
    fn test_jidoka_divergence_detection() {
        // Falsifiable claim A-004
        let guard = JidokaGuard::divergence_guard(1e-5, "cross_backend");
        let a: [f32; 4] = [1.0, 2.0, 3.0, 4.0];
        let b: [f32; 4] = [1.0, 2.0, 3.1, 4.0]; // 0.1 diff at index 2

        match guard.check_divergence(&a, &b) {
            Err(JidokaError::BackendDivergence { max_diff, .. }) => {
                assert!((max_diff - 0.1).abs() < 1e-6);
            }
            _ => panic!("Expected BackendDivergence error"),
        }
    }

    #[test]
    fn test_jidoka_divergence_within_tolerance() {
        let guard = JidokaGuard::divergence_guard(1e-5, "cross_backend");
        let a: [f32; 4] = [1.0, 2.0, 3.0, 4.0];
        let b: [f32; 4] = [1.0, 2.0, 3.0 + 1e-7, 4.0]; // Within tolerance

        assert!(guard.check_divergence(&a, &b).is_ok());
    }

    // ================================================================
    // Determinism
    // ================================================================

    #[test]
    fn test_jidoka_determinism_check() {
        // Falsifiable claim B-017
        let first_run: [f32; 4] = [1.0, 2.0, 3.0, 4.0];
        let second_run = first_run;

        assert!(determinism_guard().check_determinism(&first_run, &second_run).is_ok());
    }

    #[test]
    fn test_jidoka_determinism_failure() {
        let a: [f32; 4] = [1.0, 2.0, 3.0, 4.0];
        let b: [f32; 4] = [1.0, 2.0, 3.000_001, 4.0]; // Different bit pattern

        // Verify they actually have different bits
        assert_ne!(a[2].to_bits(), b[2].to_bits(), "Test values must differ");

        match determinism_guard().check_determinism(&a, &b) {
            Err(JidokaError::DeterminismFailure { first_diff_index, .. }) => {
                assert_eq!(first_diff_index, 2);
            }
            _ => panic!("Expected DeterminismFailure error"),
        }
    }

    // ================================================================
    // JidokaError Display / Error
    // ================================================================

    #[test]
    fn test_jidoka_error_display() {
        let nan_msg =
            JidokaError::NanDetected { context: "test".to_string(), indices: vec![0, 2] }
                .to_string();
        assert!(nan_msg.contains("NaN"));
        assert!(nan_msg.contains("test"));

        let divergence_msg = JidokaError::BackendDivergence {
            context: "cross".to_string(),
            max_diff: 0.01,
            tolerance: 0.001,
        }
        .to_string();
        assert!(divergence_msg.contains("divergence"));
    }

    #[test]
    fn test_jidoka_error_display_inf_detected() {
        let display =
            JidokaError::InfDetected { context: "matmul_output".to_string(), indices: vec![1, 3] }
                .to_string();

        let expectations = [
            ("Infinity", "'Infinity'"),
            ("matmul_output", "context"),
            ("[1, 3]", "indices"),
        ];
        for (needle, what) in expectations {
            assert!(display.contains(needle), "Display should contain {what}, got: {display}");
        }
    }

    #[test]
    fn test_jidoka_error_display_performance_regression() {
        let display = JidokaError::PerformanceRegression {
            context: "avx2_dot_product".to_string(),
            regression_pct: 15.75,
            threshold_pct: 5.0,
        }
        .to_string();

        let expectations = [
            ("Performance regression", "'Performance regression'"),
            ("avx2_dot_product", "context"),
            ("15.75", "regression_pct"),
            ("5.00", "threshold_pct"),
        ];
        for (needle, what) in expectations {
            assert!(display.contains(needle), "Display should contain {what}, got: {display}");
        }
    }

    #[test]
    fn test_jidoka_error_display_determinism_failure() {
        let display = JidokaError::DeterminismFailure {
            context: "sse2_vs_avx2".to_string(),
            first_diff_index: 42,
        }
        .to_string();

        let expectations = [
            ("Determinism failure", "'Determinism failure'"),
            ("sse2_vs_avx2", "context"),
            ("42", "first_diff_index"),
        ];
        for (needle, what) in expectations {
            assert!(display.contains(needle), "Display should contain {what}, got: {display}");
        }
    }

    #[test]
    fn test_jidoka_error_is_std_error() {
        // Verify the std::error::Error impl works for all variants
        let errors: Vec<Box<dyn std::error::Error>> = vec![
            Box::new(JidokaError::NanDetected { context: "a".to_string(), indices: vec![] }),
            Box::new(JidokaError::InfDetected { context: "b".to_string(), indices: vec![] }),
            Box::new(JidokaError::BackendDivergence {
                context: "c".to_string(),
                max_diff: 0.0,
                tolerance: 0.0,
            }),
            Box::new(JidokaError::PerformanceRegression {
                context: "d".to_string(),
                regression_pct: 0.0,
                threshold_pct: 0.0,
            }),
            Box::new(JidokaError::DeterminismFailure {
                context: "e".to_string(),
                first_diff_index: 0,
            }),
        ];

        // All variants should produce non-empty Display output via Error trait
        assert!(
            errors.iter().all(|err| !err.to_string().is_empty()),
            "Error::to_string() should produce non-empty output"
        );
    }

    // ================================================================
    // Edge cases and derived traits
    // ================================================================

    #[test]
    fn test_empty_output_checks() {
        assert!(JidokaGuard::nan_guard("empty_test").check_output(&[]).is_ok());
    }

    #[test]
    fn test_single_element_checks() {
        let guard = JidokaGuard::nan_guard("single_test");

        assert!(guard.check_output(&[1.0]).is_ok());
        assert!(guard.check_output(&[f32::NAN]).is_err());
    }

    #[test]
    fn test_jidoka_condition_clone() {
        let original = JidokaCondition::BackendDivergence { tolerance: 1e-5 };
        assert_eq!(original.clone(), original);
    }

    #[test]
    fn test_jidoka_action_eq() {
        let stop = JidokaAction::Stop;
        assert_eq!(stop, JidokaAction::Stop);
        assert_ne!(stop, JidokaAction::LogAndContinue);
    }
}