use std::cell::Cell;
use std::fmt;
thread_local! {
    // Per-thread flag: anomaly detection on one thread never affects another.
    static ANOMALY_ENABLED: Cell<bool> = const { Cell::new(false) };
}

/// Thread-local toggle for gradient anomaly detection.
///
/// All three operations act only on the calling thread's flag.
pub struct AnomalyMode;

impl AnomalyMode {
    /// Switch anomaly detection on for the current thread.
    pub fn enable() {
        ANOMALY_ENABLED.with(|flag| flag.set(true));
    }

    /// Switch anomaly detection off for the current thread.
    pub fn disable() {
        ANOMALY_ENABLED.with(|flag| flag.set(false));
    }

    /// Whether anomaly detection is currently active on this thread.
    pub fn is_enabled() -> bool {
        ANOMALY_ENABLED.with(Cell::get)
    }
}
pub fn detect_anomaly<F, R>(f: F) -> R
where
F: FnOnce() -> R,
{
struct AnomalyGuard {
prev: bool,
}
impl Drop for AnomalyGuard {
fn drop(&mut self) {
ANOMALY_ENABLED.with(|c| c.set(self.prev));
}
}
let _guard = AnomalyGuard {
prev: AnomalyMode::is_enabled(),
};
AnomalyMode::enable();
f()
}
/// A backtrace captured during the forward pass, stored as rendered text.
///
/// Produced by [`ForwardBacktrace::capture_if_enabled`], which only
/// captures while anomaly mode is active on the current thread.
#[derive(Clone)]
pub struct ForwardBacktrace {
    // Rendered output of `Backtrace::to_string` at capture time.
    trace: String,
}
impl ForwardBacktrace {
    /// Captures and renders the current call stack, but only while anomaly
    /// mode is active on this thread; returns `None` otherwise so callers
    /// pay the capture cost only when detection is on.
    pub fn capture_if_enabled() -> Option<Self> {
        AnomalyMode::is_enabled().then(|| {
            let captured = std::backtrace::Backtrace::capture();
            Self {
                trace: captured.to_string(),
            }
        })
    }

    /// The captured backtrace rendered as text.
    pub fn trace(&self) -> &str {
        self.trace.as_str()
    }
}
impl fmt::Debug for ForwardBacktrace {
    // Debug prints a placeholder instead of the trace text, keeping derived
    // Debug output of containing types compact; use `Display` to see the
    // full backtrace.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("ForwardBacktrace")
            .field("trace", &"<backtrace>")
            .finish()
    }
}
impl fmt::Display for ForwardBacktrace {
    /// Renders a heading line followed by the full captured trace text.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        writeln!(f, "Forward-pass backtrace:")?;
        f.write_str(&self.trace)
    }
}
/// Checks a gradient tensor for NaN/Inf values while anomaly mode is on.
///
/// Returns `Ok(())` when anomaly mode is disabled, when the tensor is on
/// CUDA (this path only inspects host-side data), or when every element is
/// finite. Otherwise returns an `InvalidArgument` error naming the kind of
/// anomaly, the producing op, and — when provided — the forward-pass
/// backtrace.
///
/// # Errors
/// Propagates any error from `grad.data()`, and returns
/// `FerrotorchError::InvalidArgument` when a NaN or Inf is found.
pub fn check_gradient_anomaly<T: crate::dtype::Float>(
    grad: &crate::tensor::Tensor<T>,
    op_name: &str,
    forward_bt: Option<&ForwardBacktrace>,
) -> crate::error::FerrotorchResult<()> {
    if !AnomalyMode::is_enabled() {
        return Ok(());
    }
    // CUDA tensors are skipped — presumably their data is not reachable
    // through the host-side `data()` path used below (TODO confirm).
    if grad.is_cuda() {
        return Ok(());
    }
    let data = grad.data()?;
    // Single pass over the gradient, tracking both flags at once (the
    // original scanned the buffer twice); bail out early once both kinds
    // of anomaly have been observed.
    let mut has_nan = false;
    let mut has_inf = false;
    for v in data.iter() {
        has_nan = has_nan || v.is_nan();
        has_inf = has_inf || v.is_infinite();
        if has_nan && has_inf {
            break;
        }
    }
    if has_nan || has_inf {
        let anomaly_kind = if has_nan && has_inf {
            "NaN and Inf"
        } else if has_nan {
            "NaN"
        } else {
            "Inf"
        };
        let bt_msg = match forward_bt {
            Some(bt) => format!("\n\n{bt}"),
            None => String::from(
                "\n\n(no forward backtrace available — was anomaly mode enabled during forward pass?)",
            ),
        };
        return Err(crate::error::FerrotorchError::InvalidArgument {
            message: format!("anomaly detected: {anomaly_kind} in gradient of {op_name}{bt_msg}"),
        });
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    // NOTE(review): every test toggles ANOMALY_ENABLED, which is
    // thread-local; the default test harness runs each test on its own
    // thread, so these tests should not race — verify if a custom harness
    // or single-threaded runner is ever used.

    // Flag reads as disabled after an explicit disable.
    #[test]
    fn test_anomaly_mode_default_off() {
        AnomalyMode::disable();
        assert!(!AnomalyMode::is_enabled());
    }

    // enable()/disable() round-trip is observable through is_enabled().
    #[test]
    fn test_anomaly_mode_enable_disable() {
        AnomalyMode::enable();
        assert!(AnomalyMode::is_enabled());
        AnomalyMode::disable();
        assert!(!AnomalyMode::is_enabled());
    }

    // detect_anomaly enables the flag only for the duration of the closure.
    #[test]
    fn test_detect_anomaly_scoped() {
        AnomalyMode::disable();
        assert!(!AnomalyMode::is_enabled());
        detect_anomaly(|| {
            assert!(AnomalyMode::is_enabled());
        });
        assert!(!AnomalyMode::is_enabled());
    }

    // The drop guard restores the flag even when the closure panics.
    #[test]
    fn test_detect_anomaly_panic_safety() {
        AnomalyMode::disable();
        let result = std::panic::catch_unwind(|| {
            detect_anomaly(|| {
                assert!(AnomalyMode::is_enabled());
                panic!("intentional panic inside detect_anomaly");
            });
        });
        assert!(result.is_err());
        assert!(!AnomalyMode::is_enabled());
    }

    // Nested scopes: the inner scope's exit must not disable the outer one.
    #[test]
    fn test_detect_anomaly_nested() {
        AnomalyMode::disable();
        detect_anomaly(|| {
            assert!(AnomalyMode::is_enabled());
            detect_anomaly(|| {
                assert!(AnomalyMode::is_enabled());
            });
            assert!(AnomalyMode::is_enabled());
        });
        assert!(!AnomalyMode::is_enabled());
    }

    // capture_if_enabled yields None while anomaly mode is off.
    #[test]
    fn test_forward_backtrace_capture_when_disabled() {
        AnomalyMode::disable();
        assert!(ForwardBacktrace::capture_if_enabled().is_none());
    }

    // With anomaly mode on, a non-empty rendered trace is captured.
    #[test]
    fn test_forward_backtrace_capture_when_enabled() {
        AnomalyMode::enable();
        let bt = ForwardBacktrace::capture_if_enabled();
        AnomalyMode::disable();
        assert!(bt.is_some());
        assert!(!bt.unwrap().trace().is_empty());
    }

    // A finite gradient passes the check even with anomaly mode enabled.
    #[test]
    fn test_check_gradient_anomaly_clean() {
        use crate::storage::TensorStorage;
        use crate::tensor::Tensor;
        AnomalyMode::enable();
        let grad =
            Tensor::<f32>::from_storage(TensorStorage::cpu(vec![1.0, 2.0, 3.0]), vec![3], false)
                .unwrap();
        let result = check_gradient_anomaly(&grad, "TestOp", None);
        AnomalyMode::disable();
        assert!(result.is_ok());
    }

    // A NaN in the gradient yields an error naming both the kind and the op.
    #[test]
    fn test_check_gradient_anomaly_nan() {
        use crate::storage::TensorStorage;
        use crate::tensor::Tensor;
        AnomalyMode::enable();
        let grad = Tensor::<f32>::from_storage(
            TensorStorage::cpu(vec![1.0, f32::NAN, 3.0]),
            vec![3],
            false,
        )
        .unwrap();
        let result = check_gradient_anomaly(&grad, "TestOp", None);
        AnomalyMode::disable();
        assert!(result.is_err());
        let msg = format!("{}", result.unwrap_err());
        assert!(msg.contains("NaN"));
        assert!(msg.contains("TestOp"));
    }

    // An infinity in the gradient is likewise reported as an error.
    #[test]
    fn test_check_gradient_anomaly_inf() {
        use crate::storage::TensorStorage;
        use crate::tensor::Tensor;
        AnomalyMode::enable();
        let grad = Tensor::<f32>::from_storage(
            TensorStorage::cpu(vec![1.0, f32::INFINITY, 3.0]),
            vec![3],
            false,
        )
        .unwrap();
        let result = check_gradient_anomaly(&grad, "TestOp", None);
        AnomalyMode::disable();
        assert!(result.is_err());
        let msg = format!("{}", result.unwrap_err());
        assert!(msg.contains("Inf"));
    }

    // When a forward backtrace is supplied, its Display header appears in
    // the error message.
    #[test]
    fn test_check_gradient_anomaly_with_backtrace() {
        use crate::storage::TensorStorage;
        use crate::tensor::Tensor;
        AnomalyMode::enable();
        let bt = ForwardBacktrace::capture_if_enabled().unwrap();
        let grad =
            Tensor::<f32>::from_storage(TensorStorage::cpu(vec![f32::NAN]), vec![], false).unwrap();
        let result = check_gradient_anomaly(&grad, "BadOp", Some(&bt));
        AnomalyMode::disable();
        assert!(result.is_err());
        let msg = format!("{}", result.unwrap_err());
        assert!(msg.contains("Forward-pass backtrace"));
    }

    // With anomaly mode off, even a NaN gradient passes (check is skipped).
    #[test]
    fn test_check_gradient_anomaly_skipped_when_disabled() {
        use crate::storage::TensorStorage;
        use crate::tensor::Tensor;
        AnomalyMode::disable();
        let grad =
            Tensor::<f32>::from_storage(TensorStorage::cpu(vec![f32::NAN]), vec![], false).unwrap();
        assert!(check_gradient_anomaly(&grad, "TestOp", None).is_ok());
    }
}