Skip to main content

trueno/blis/
jidoka.rs

1#![allow(missing_docs)]
2//! Jidoka (Autonomation) - Stop on Defect
3//!
4//! Runtime guards that detect and halt computation on numerical errors.
5//! Part of Toyota Production System integration in BLIS.
6//!
7//! # Philosophy
8//!
9//! Jidoka (自働化) is a Toyota Production System principle meaning "automation
10//! with a human touch." When a defect is detected, the process stops immediately
11//! rather than propagating bad data downstream.
12//!
13//! # Usage
14//!
15//! ```
16//! use trueno::blis::jidoka::{JidokaGuard, JidokaError};
17//!
18//! let guard = JidokaGuard::strict();
19//! guard.check_input(1.0, "matrix_a")?;
20//! let computed = 1.0f32;
21//! let expected = 1.0f32;
22//! guard.validate(computed, expected)?;
23//! # Ok::<(), JidokaError>(())
24//! ```
25
/// Errors produced when a Jidoka guard halts a computation.
#[derive(Debug, Clone, PartialEq)]
pub enum JidokaError {
    /// The computed value differs from its reference by more than the allowed
    /// relative tolerance.
    NumericalDeviation {
        /// Value produced by the computation under test.
        computed: f32,
        /// Reference value it was compared against.
        expected: f32,
        /// Relative error that triggered the stop.
        relative_error: f32,
    },
    /// A NaN was observed; `location` labels where.
    NaNDetected { location: &'static str },
    /// An infinite value was observed; `location` labels where.
    InfDetected { location: &'static str },
    /// The expected and actual dimension triples disagree.
    DimensionMismatch {
        /// Dimensions that were required.
        expected: (usize, usize, usize),
        /// Dimensions that were supplied.
        actual: (usize, usize, usize),
    },
}

impl std::fmt::Display for JidokaError {
    /// Renders a one-line, human-readable message prefixed with `Jidoka:`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // All payload fields are `Copy`, so the variants can be matched by value.
        match *self {
            Self::NumericalDeviation { computed, expected, relative_error } => write!(
                f,
                "Jidoka: numerical deviation - computed={}, expected={}, error={}",
                computed, expected, relative_error
            ),
            Self::NaNDetected { location } => write!(f, "Jidoka: NaN detected at {}", location),
            Self::InfDetected { location } => write!(f, "Jidoka: Inf detected at {}", location),
            Self::DimensionMismatch { expected, actual } => write!(
                f,
                "Jidoka: dimension mismatch - expected {:?}, got {:?}",
                expected, actual
            ),
        }
    }
}

/// Lets `JidokaError` flow through `?` into `Box<dyn Error>` and similar
/// error-handling paths; it wraps no underlying source error.
impl std::error::Error for JidokaError {}
63
/// Runtime guard that validates numerical values and stops on defects.
///
/// Configure it through its public fields, via [`Default`], or with
/// [`JidokaGuard::strict`].
#[derive(Debug, Clone)]
pub struct JidokaGuard {
    /// Largest relative error `validate` tolerates before reporting a deviation.
    pub epsilon: f32,
    /// When `true`, NaN and infinite values are rejected with dedicated errors.
    pub check_special: bool,
    /// Sampling interval: intended to mean "check one output in every N".
    ///
    /// NOTE(review): the guard's own methods in this module never read this
    /// field, so any sampling must be applied by callers. Confirm at call sites.
    pub sample_rate: usize,
}

impl Default for JidokaGuard {
    /// Looser settings than `strict()`: relative tolerance `1e-5`, NaN/Inf
    /// checks enabled, and a sampling interval of 1000.
    fn default() -> Self {
        let epsilon = 1e-5;
        let check_special = true;
        // One output in every 1000. This default is the same for every build
        // profile; it is not switched on debug vs. release.
        let sample_rate = 1000;
        Self { epsilon, check_special, sample_rate }
    }
}
84
85impl JidokaGuard {
86    /// Create a strict guard for testing (checks every output)
87    pub fn strict() -> Self {
88        Self { epsilon: 1e-6, check_special: true, sample_rate: 1 }
89    }
90
91    /// Validate a computed value against expected
92    #[inline]
93    pub fn validate(&self, computed: f32, expected: f32) -> Result<(), JidokaError> {
94        if self.check_special {
95            if computed.is_nan() {
96                return Err(JidokaError::NaNDetected { location: "output" });
97            }
98            if computed.is_infinite() {
99                return Err(JidokaError::InfDetected { location: "output" });
100            }
101        }
102
103        let abs_diff = (computed - expected).abs();
104        let max_abs = computed.abs().max(expected.abs()).max(1e-10);
105        let relative_error = abs_diff / max_abs;
106
107        if relative_error > self.epsilon {
108            return Err(JidokaError::NumericalDeviation { computed, expected, relative_error });
109        }
110
111        Ok(())
112    }
113
114    /// Check input for NaN/Inf
115    #[inline]
116    pub fn check_input(&self, value: f32, location: &'static str) -> Result<(), JidokaError> {
117        if !self.check_special {
118            return Ok(());
119        }
120        if value.is_nan() {
121            return Err(JidokaError::NaNDetected { location });
122        }
123        if value.is_infinite() {
124            return Err(JidokaError::InfDetected { location });
125        }
126        Ok(())
127    }
128}
129
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance comparison shared by the constructor tests.
    fn approx(a: f32, b: f32) -> bool {
        (a - b).abs() < 1e-10
    }

    #[test]
    fn test_jidoka_default() {
        let JidokaGuard { epsilon, check_special, sample_rate } = JidokaGuard::default();
        assert!(approx(epsilon, 1e-5));
        assert!(check_special);
        assert_eq!(sample_rate, 1000);
    }

    #[test]
    fn test_jidoka_strict() {
        let JidokaGuard { epsilon, check_special, sample_rate } = JidokaGuard::strict();
        assert!(approx(epsilon, 1e-6));
        assert!(check_special);
        assert_eq!(sample_rate, 1);
    }

    #[test]
    fn test_validate_pass() {
        let guard = JidokaGuard::default();
        for &(computed, expected) in [(1.0, 1.0), (1.0, 1.000001)].iter() {
            assert!(guard.validate(computed, expected).is_ok());
        }
    }

    #[test]
    fn test_validate_nan() {
        let err = JidokaGuard::default().validate(f32::NAN, 1.0).unwrap_err();
        assert!(matches!(err, JidokaError::NaNDetected { .. }));
    }

    #[test]
    fn test_validate_inf() {
        let err = JidokaGuard::default().validate(f32::INFINITY, 1.0).unwrap_err();
        assert!(matches!(err, JidokaError::InfDetected { .. }));
    }

    #[test]
    fn test_validate_deviation() {
        let err = JidokaGuard::strict().validate(1.0, 2.0).unwrap_err();
        assert!(matches!(err, JidokaError::NumericalDeviation { .. }));
    }

    #[test]
    fn test_check_input_nan() {
        let err = JidokaGuard::default().check_input(f32::NAN, "test").unwrap_err();
        assert!(matches!(err, JidokaError::NaNDetected { .. }));
    }

    #[test]
    fn test_check_input_inf() {
        let err = JidokaGuard::default().check_input(f32::INFINITY, "test").unwrap_err();
        assert!(matches!(err, JidokaError::InfDetected { .. }));
    }

    #[test]
    fn test_error_display() {
        // Each variant's rendered message must mention its distinguishing keyword.
        let cases = [
            (JidokaError::NaNDetected { location: "test" }, "NaN"),
            (JidokaError::InfDetected { location: "test" }, "Inf"),
            (
                JidokaError::NumericalDeviation {
                    computed: 1.0,
                    expected: 2.0,
                    relative_error: 0.5,
                },
                "deviation",
            ),
            (JidokaError::DimensionMismatch { expected: (1, 2, 3), actual: (4, 5, 6) }, "mismatch"),
        ];
        for (err, needle) in cases.iter() {
            assert!(err.to_string().contains(needle));
        }
    }
}