/// The quality condition a `JidokaGuard` watches for.
///
/// Each variant selects which `JidokaGuard::check_*` method performs real work; the other
/// check methods treat the guard as a no-op.
#[derive(Debug, Clone, PartialEq)]
pub enum JidokaCondition {
    /// Some output element is NaN (checked by `check_output`).
    NanDetected,
    /// Some output element is positive or negative infinity (checked by `check_output`).
    InfDetected,
    /// Two backends' outputs differ element-wise by more than `tolerance` in absolute value
    /// (checked by `check_divergence`).
    BackendDivergence { tolerance: f32 },
    /// A slowdown beyond `threshold_pct` percent. No `check_*` method evaluates this
    /// condition; it is carried for reporting only.
    PerformanceRegression { threshold_pct: f32 },
    /// Two runs are not bit-for-bit identical (checked by `check_determinism`).
    DeterminismFailure,
}
/// Reaction associated with a `JidokaGuard` when its condition is violated.
///
/// NOTE(review): the `JidokaGuard::check_*` methods never read this value. They always return
/// `Err` on a violation, so the caller presumably interprets the action — confirm with callers.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum JidokaAction {
    /// Presumably: stop processing on violation. This is the action used by `nan_guard`,
    /// `inf_guard` and `divergence_guard`.
    Stop,
    /// Presumably: record the violation and keep going (not implemented in this file).
    LogAndContinue,
    /// Presumably: produce a visual/diagnostic report (not implemented in this file).
    VisualReport,
}
/// Failure reported by a `JidokaGuard` check.
///
/// Every variant carries the guard's `context` label plus the details of the violation.
#[derive(Debug, Clone)]
pub enum JidokaError {
    /// NaN values were found; `indices` lists every offending position.
    NanDetected { context: String, indices: Vec<usize> },
    /// Infinite values (either sign) were found; `indices` lists every offending position.
    InfDetected { context: String, indices: Vec<usize> },
    /// Backends disagreed: `max_diff` is the largest element-wise discrepancy found and
    /// `tolerance` the configured limit it exceeded.
    BackendDivergence { context: String, max_diff: f32, tolerance: f32 },
    /// A slowdown of `regression_pct` percent exceeded the `threshold_pct` percent limit.
    PerformanceRegression { context: String, regression_pct: f32, threshold_pct: f32 },
    /// Two runs first differed (bitwise) at `first_diff_index`.
    DeterminismFailure { context: String, first_diff_index: usize },
}

impl std::fmt::Display for JidokaError {
    /// Renders a one-line, human-readable description prefixed with `Jidoka:`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::NanDetected { context, indices } => write!(
                f,
                "Jidoka: NaN detected at {context} (indices: {indices:?})"
            ),
            Self::InfDetected { context, indices } => write!(
                f,
                "Jidoka: Infinity detected at {context} (indices: {indices:?})"
            ),
            Self::BackendDivergence { context, max_diff, tolerance } => write!(
                f,
                "Jidoka: Backend divergence at {context} (max_diff: {max_diff}, tolerance: {tolerance})"
            ),
            Self::PerformanceRegression { context, regression_pct, threshold_pct } => write!(
                f,
                "Jidoka: Performance regression at {context} ({regression_pct:.2}% > {threshold_pct:.2}%)"
            ),
            Self::DeterminismFailure { context, first_diff_index } => write!(
                f,
                "Jidoka: Determinism failure at {context} (first diff at index {first_diff_index})"
            ),
        }
    }
}

/// Lets `JidokaError` travel through `Box<dyn Error>` / `?` chains; it has no underlying source.
impl std::error::Error for JidokaError {}
/// A single quality check: the condition to look for, the reaction the caller is meant to
/// take, and a label identifying the checked operation.
#[derive(Debug, Clone)]
pub struct JidokaGuard {
    /// Condition this guard evaluates; it selects which `check_*` method does real work.
    pub condition: JidokaCondition,
    /// Intended reaction on violation. The `check_*` methods do not read it.
    /// NOTE(review): presumably consumed by callers — confirm.
    pub action: JidokaAction,
    /// Label of the checked operation; cloned into every `JidokaError` this guard returns.
    pub context: String,
}
impl JidokaGuard {
    /// Creates a guard for `condition`, tagged with `action` and a `context` label.
    ///
    /// `context` is cloned into every [`JidokaError`] the guard produces, so failures can be
    /// traced back to the operation being checked.
    #[must_use]
    pub fn new(
        condition: JidokaCondition,
        action: JidokaAction,
        context: impl Into<String>,
    ) -> Self {
        Self { condition, action, context: context.into() }
    }

    /// Creates a guard that stops on any NaN in an output buffer (see [`Self::check_output`]).
    #[must_use]
    pub fn nan_guard(context: impl Into<String>) -> Self {
        Self::new(JidokaCondition::NanDetected, JidokaAction::Stop, context)
    }

    /// Creates a guard that stops on any positive or negative infinity in an output buffer
    /// (see [`Self::check_output`]).
    #[must_use]
    pub fn inf_guard(context: impl Into<String>) -> Self {
        Self::new(JidokaCondition::InfDetected, JidokaAction::Stop, context)
    }

    /// Creates a guard that stops when two backends' outputs differ by more than `tolerance`
    /// (see [`Self::check_divergence`]).
    #[must_use]
    pub fn divergence_guard(tolerance: f32, context: impl Into<String>) -> Self {
        Self::new(JidokaCondition::BackendDivergence { tolerance }, JidokaAction::Stop, context)
    }

    /// Scans one output buffer for this guard's NaN / infinity condition.
    ///
    /// Only `NanDetected` and `InfDetected` guards inspect `output`; for every other condition
    /// this is a no-op that returns `Ok(())`.
    ///
    /// # Errors
    ///
    /// Returns [`JidokaError::NanDetected`] or [`JidokaError::InfDetected`] listing every
    /// offending index when the watched value occurs in `output`.
    pub fn check_output(&self, output: &[f32]) -> Result<(), JidokaError> {
        match &self.condition {
            JidokaCondition::NanDetected => {
                let indices = Self::indices_where(output, f32::is_nan);
                if !indices.is_empty() {
                    return Err(JidokaError::NanDetected { context: self.context.clone(), indices });
                }
            }
            JidokaCondition::InfDetected => {
                let indices = Self::indices_where(output, f32::is_infinite);
                if !indices.is_empty() {
                    return Err(JidokaError::InfDetected { context: self.context.clone(), indices });
                }
            }
            JidokaCondition::BackendDivergence { .. }
            | JidokaCondition::PerformanceRegression { .. }
            | JidokaCondition::DeterminismFailure => {}
        }
        Ok(())
    }

    /// Compares two backends' outputs element-wise against the guard's tolerance.
    ///
    /// The two sides agree on an element when the values are equal (including equal
    /// infinities) or both are NaN. A NaN on only one side counts as infinite divergence.
    /// Slices of different lengths are also treated as infinitely divergent. Only
    /// `BackendDivergence` guards do anything; other conditions return `Ok(())`.
    ///
    /// # Errors
    ///
    /// Returns [`JidokaError::BackendDivergence`] when the largest element-wise difference
    /// exceeds the configured tolerance.
    pub fn check_divergence(&self, a: &[f32], b: &[f32]) -> Result<(), JidokaError> {
        if let JidokaCondition::BackendDivergence { tolerance } = self.condition {
            let max_diff = if a.len() == b.len() {
                a.iter()
                    .zip(b)
                    .map(|(&x, &y)| Self::element_divergence(x, y))
                    .fold(0.0_f32, f32::max)
            } else {
                // `zip` would silently drop the unmatched tail and hide the mismatch.
                f32::INFINITY
            };
            if max_diff > tolerance {
                return Err(JidokaError::BackendDivergence {
                    context: self.context.clone(),
                    max_diff,
                    tolerance,
                });
            }
        }
        Ok(())
    }

    /// Checks that two runs produced bit-for-bit identical outputs.
    ///
    /// Elements are compared by their raw bit patterns, so NaN payloads and `0.0` vs `-0.0`
    /// count as differences. If the common prefix matches but the lengths differ, the first
    /// difference is reported at the end of the shorter slice. Only `DeterminismFailure`
    /// guards do anything; other conditions return `Ok(())`.
    ///
    /// # Errors
    ///
    /// Returns [`JidokaError::DeterminismFailure`] carrying the first differing index.
    pub fn check_determinism(&self, a: &[f32], b: &[f32]) -> Result<(), JidokaError> {
        if matches!(self.condition, JidokaCondition::DeterminismFailure) {
            let first_diff = a
                .iter()
                .zip(b)
                .position(|(x, y)| x.to_bits() != y.to_bits())
                // Identical prefix but different lengths: the runs diverge where the shorter ends.
                .or_else(|| (a.len() != b.len()).then(|| a.len().min(b.len())));
            if let Some(first_diff_index) = first_diff {
                return Err(JidokaError::DeterminismFailure {
                    context: self.context.clone(),
                    first_diff_index,
                });
            }
        }
        Ok(())
    }

    /// Returns the indices of all elements of `output` for which `predicate` holds.
    fn indices_where(output: &[f32], predicate: impl Fn(f32) -> bool) -> Vec<usize> {
        output
            .iter()
            .enumerate()
            .filter(|&(_, &x)| predicate(x))
            .map(|(i, _)| i)
            .collect()
    }

    /// Absolute difference of one element pair, with non-finite values handled explicitly.
    fn element_divergence(x: f32, y: f32) -> f32 {
        if x == y || (x.is_nan() && y.is_nan()) {
            // Equal infinities would otherwise produce `inf - inf = NaN`.
            0.0
        } else {
            let diff = (x - y).abs();
            // Reaching here with a NaN diff means exactly one side is NaN. `f32::max` would
            // silently discard a NaN, so report it as unbounded divergence instead.
            if diff.is_nan() {
                f32::INFINITY
            } else {
                diff
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Guard shared by the determinism tests.
    fn determinism_guard() -> JidokaGuard {
        JidokaGuard::new(JidokaCondition::DeterminismFailure, JidokaAction::Stop, "determinism_test")
    }

    #[test]
    fn test_jidoka_nan_detection() {
        let guard = JidokaGuard::nan_guard("test_operation");
        let samples: [f32; 4] = [1.0, 2.0, f32::NAN, 4.0];
        let outcome = guard.check_output(&samples);
        assert!(outcome.is_err());
        match outcome {
            Err(JidokaError::NanDetected { indices, .. }) => assert_eq!(indices, vec![2]),
            _ => panic!("Expected NanDetected error"),
        }
    }

    #[test]
    fn test_jidoka_nan_no_false_positive() {
        let samples: [f32; 4] = [1.0, 2.0, 3.0, 4.0];
        let outcome = JidokaGuard::nan_guard("test_operation").check_output(&samples);
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_jidoka_inf_detection() {
        let guard = JidokaGuard::inf_guard("test_operation");
        let samples: [f32; 4] = [1.0, f32::INFINITY, 3.0, f32::NEG_INFINITY];
        let outcome = guard.check_output(&samples);
        assert!(outcome.is_err());
        match outcome {
            Err(JidokaError::InfDetected { indices, .. }) => assert_eq!(indices, vec![1, 3]),
            _ => panic!("Expected InfDetected error"),
        }
    }

    #[test]
    fn test_jidoka_divergence_detection() {
        let guard = JidokaGuard::divergence_guard(1e-5, "cross_backend");
        let reference: [f32; 4] = [1.0, 2.0, 3.0, 4.0];
        let candidate: [f32; 4] = [1.0, 2.0, 3.1, 4.0];
        let outcome = guard.check_divergence(&reference, &candidate);
        assert!(outcome.is_err());
        match outcome {
            Err(JidokaError::BackendDivergence { max_diff, .. }) => {
                assert!((max_diff - 0.1).abs() < 1e-6);
            }
            _ => panic!("Expected BackendDivergence error"),
        }
    }

    #[test]
    fn test_jidoka_divergence_within_tolerance() {
        let reference: [f32; 4] = [1.0, 2.0, 3.0, 4.0];
        let candidate: [f32; 4] = [1.0, 2.0, 3.0 + 1e-7, 4.0];
        let outcome = JidokaGuard::divergence_guard(1e-5, "cross_backend")
            .check_divergence(&reference, &candidate);
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_jidoka_determinism_check() {
        let run: [f32; 4] = [1.0, 2.0, 3.0, 4.0];
        let rerun: [f32; 4] = [1.0, 2.0, 3.0, 4.0];
        assert!(determinism_guard().check_determinism(&run, &rerun).is_ok());
    }

    #[test]
    fn test_jidoka_determinism_failure() {
        let run: [f32; 4] = [1.0, 2.0, 3.0, 4.0];
        let rerun: [f32; 4] = [1.0, 2.0, 3.000_001, 4.0];
        assert_ne!(run[2].to_bits(), rerun[2].to_bits(), "Test values must differ");
        let outcome = determinism_guard().check_determinism(&run, &rerun);
        assert!(outcome.is_err());
        match outcome {
            Err(JidokaError::DeterminismFailure { first_diff_index, .. }) => {
                assert_eq!(first_diff_index, 2);
            }
            _ => panic!("Expected DeterminismFailure error"),
        }
    }

    #[test]
    fn test_jidoka_error_display() {
        let nan_text =
            JidokaError::NanDetected { context: "test".to_string(), indices: vec![0, 2] }
                .to_string();
        assert!(nan_text.contains("NaN"));
        assert!(nan_text.contains("test"));

        let divergence_text = JidokaError::BackendDivergence {
            context: "cross".to_string(),
            max_diff: 0.01,
            tolerance: 0.001,
        }
        .to_string();
        assert!(divergence_text.contains("divergence"));
    }

    #[test]
    fn test_jidoka_error_display_inf_detected() {
        let display =
            JidokaError::InfDetected { context: "matmul_output".to_string(), indices: vec![1, 3] }
                .to_string();
        assert!(display.contains("Infinity"), "Display should contain 'Infinity', got: {display}");
        assert!(
            display.contains("matmul_output"),
            "Display should contain context, got: {display}"
        );
        assert!(display.contains("[1, 3]"), "Display should contain indices, got: {display}");
    }

    #[test]
    fn test_jidoka_error_display_performance_regression() {
        let display = JidokaError::PerformanceRegression {
            context: "avx2_dot_product".to_string(),
            regression_pct: 15.75,
            threshold_pct: 5.0,
        }
        .to_string();
        assert!(
            display.contains("Performance regression"),
            "Display should contain 'Performance regression', got: {display}"
        );
        assert!(
            display.contains("avx2_dot_product"),
            "Display should contain context, got: {display}"
        );
        assert!(display.contains("15.75"), "Display should contain regression_pct, got: {display}");
        assert!(display.contains("5.00"), "Display should contain threshold_pct, got: {display}");
    }

    #[test]
    fn test_jidoka_error_display_determinism_failure() {
        let display = JidokaError::DeterminismFailure {
            context: "sse2_vs_avx2".to_string(),
            first_diff_index: 42,
        }
        .to_string();
        assert!(
            display.contains("Determinism failure"),
            "Display should contain 'Determinism failure', got: {display}"
        );
        assert!(display.contains("sse2_vs_avx2"), "Display should contain context, got: {display}");
        assert!(display.contains("42"), "Display should contain first_diff_index, got: {display}");
    }

    #[test]
    fn test_jidoka_error_is_std_error() {
        let errors: [Box<dyn std::error::Error>; 5] = [
            Box::new(JidokaError::NanDetected { context: "a".to_string(), indices: vec![] }),
            Box::new(JidokaError::InfDetected { context: "b".to_string(), indices: vec![] }),
            Box::new(JidokaError::BackendDivergence {
                context: "c".to_string(),
                max_diff: 0.0,
                tolerance: 0.0,
            }),
            Box::new(JidokaError::PerformanceRegression {
                context: "d".to_string(),
                regression_pct: 0.0,
                threshold_pct: 0.0,
            }),
            Box::new(JidokaError::DeterminismFailure {
                context: "e".to_string(),
                first_diff_index: 0,
            }),
        ];
        assert!(
            errors.iter().all(|err| !err.to_string().is_empty()),
            "Error::to_string() should produce non-empty output"
        );
    }

    #[test]
    fn test_empty_output_checks() {
        assert!(JidokaGuard::nan_guard("empty_test").check_output(&[]).is_ok());
    }

    #[test]
    fn test_single_element_checks() {
        let guard = JidokaGuard::nan_guard("single_test");
        let clean: [f32; 1] = [1.0];
        let poisoned: [f32; 1] = [f32::NAN];
        assert!(guard.check_output(&clean).is_ok());
        assert!(guard.check_output(&poisoned).is_err());
    }

    #[test]
    fn test_jidoka_condition_clone() {
        let original = JidokaCondition::BackendDivergence { tolerance: 1e-5 };
        let duplicate = original.clone();
        assert_eq!(duplicate, original);
    }

    #[test]
    fn test_jidoka_action_eq() {
        let stop = JidokaAction::Stop;
        assert_eq!(stop, JidokaAction::Stop);
        assert_ne!(stop, JidokaAction::LogAndContinue);
    }
}