use std::cell::Cell;
thread_local! {
    /// Per-thread gradient-tracking switch; every thread starts with it turned on.
    static GRAD_ENABLED: Cell<bool> = const { Cell::new(true) };
}

/// Reports whether gradient tracking is currently on for the calling thread.
///
/// The flag lives in thread-local storage, so each thread observes only its own
/// setting; a thread that has never changed it sees `true`.
pub fn is_grad_enabled() -> bool {
    GRAD_ENABLED.with(Cell::get)
}
/// Runs `f` with the current thread's grad mode forced to `enabled`, then restores
/// the mode that was in effect before the call.
///
/// The restore happens in a `Drop` guard, so it also runs when `f` panics and the
/// panic unwinds through this frame. As a consequence, any change `f` makes through
/// `set_grad_enabled` is discarded when `f` exits.
fn with_grad_mode<F, R>(enabled: bool, f: F) -> R
where
    F: FnOnce() -> R,
{
    /// Writes the saved grad mode back into the thread-local flag when dropped.
    struct GradModeGuard {
        prev: bool,
    }

    impl Drop for GradModeGuard {
        fn drop(&mut self) {
            GRAD_ENABLED.with(|g| g.set(self.prev));
        }
    }

    // `replace` saves the old mode and installs the new one in a single TLS access,
    // and the guard takes ownership of the saved value before `f` can run.
    let _guard = GradModeGuard {
        prev: GRAD_ENABLED.with(|g| g.replace(enabled)),
    };
    f()
}

/// Runs `f` with gradient tracking disabled on the current thread and returns its result.
///
/// The previous mode is restored when `f` returns or unwinds, so calls nest freely.
/// Only the calling thread is affected: threads spawned inside `f` start from the
/// thread-local default (enabled).
pub fn no_grad<F, R>(f: F) -> R
where
    F: FnOnce() -> R,
{
    with_grad_mode(false, f)
}

/// Runs `f` with gradient tracking enabled on the current thread and returns its result.
///
/// This is the counterpart of `no_grad`: it re-enables tracking inside an enclosing
/// `no_grad` scope. The previous mode is restored when `f` returns or unwinds.
pub fn enable_grad<F, R>(f: F) -> R
where
    F: FnOnce() -> R,
{
    with_grad_mode(true, f)
}
/// Turns gradient tracking on or off for the calling thread.
///
/// Unlike [`no_grad`] and [`enable_grad`], this does not restore anything
/// afterwards. If it is called inside one of those scopes, the value lasts only
/// until the scope exits, at which point the scope writes back the mode it saved
/// on entry. Other threads are unaffected.
pub fn set_grad_enabled(enabled: bool) {
    GRAD_ENABLED.with(move |flag| flag.set(enabled))
}
#[cfg(test)]
mod tests {
    // These tests assume every test body starts with grad mode at its default
    // (enabled). They rely on the harness running each test on a fresh thread,
    // because the flag is thread-local.
    use super::*;

    #[test]
    fn test_grad_enabled_default() {
        assert!(is_grad_enabled());
    }

    #[test]
    fn test_no_grad_disables() {
        assert!(is_grad_enabled());
        no_grad(|| {
            assert!(!is_grad_enabled());
        });
        assert!(is_grad_enabled());
    }

    #[test]
    fn test_no_grad_returns_value() {
        let val = no_grad(|| {
            assert!(!is_grad_enabled());
            "done"
        });
        assert_eq!(val, "done");
        assert!(is_grad_enabled());
    }

    #[test]
    fn test_no_grad_nested() {
        assert!(is_grad_enabled());
        no_grad(|| {
            assert!(!is_grad_enabled());
            no_grad(|| {
                assert!(!is_grad_enabled());
            });
            assert!(!is_grad_enabled());
        });
        assert!(is_grad_enabled());
    }

    #[test]
    fn test_enable_grad_inside_no_grad() {
        no_grad(|| {
            assert!(!is_grad_enabled());
            enable_grad(|| {
                assert!(is_grad_enabled());
            });
            assert!(!is_grad_enabled());
        });
        assert!(is_grad_enabled());
    }

    #[test]
    fn test_enable_grad_returns_value() {
        let val = no_grad(|| enable_grad(|| 42));
        assert_eq!(val, 42);
    }

    #[test]
    fn test_enable_grad_when_already_enabled() {
        assert!(is_grad_enabled());
        let result = enable_grad(|| {
            assert!(is_grad_enabled());
            99
        });
        assert_eq!(result, 99);
        assert!(is_grad_enabled());
    }

    #[test]
    fn test_set_grad_enabled() {
        assert!(is_grad_enabled());
        set_grad_enabled(false);
        assert!(!is_grad_enabled());
        set_grad_enabled(true);
        assert!(is_grad_enabled());
    }

    #[test]
    fn test_set_grad_enabled_inside_no_grad() {
        no_grad(|| {
            assert!(!is_grad_enabled());
            set_grad_enabled(true);
            assert!(is_grad_enabled());
            set_grad_enabled(false);
            assert!(!is_grad_enabled());
        });
        assert!(is_grad_enabled());
    }

    #[test]
    fn test_grad_mode_is_thread_local() {
        // A thread spawned inside `no_grad` starts from the default, not the parent's mode.
        no_grad(|| {
            let child_saw = std::thread::spawn(is_grad_enabled)
                .join()
                .expect("child thread panicked");
            assert!(child_saw);
            assert!(!is_grad_enabled());
        });
        // Disabling grad on another thread must not leak into this one.
        std::thread::spawn(|| set_grad_enabled(false))
            .join()
            .expect("child thread panicked");
        assert!(is_grad_enabled());
    }

    #[test]
    fn test_no_grad_panic_safety() {
        assert!(is_grad_enabled());
        let result = std::panic::catch_unwind(|| {
            no_grad(|| {
                assert!(!is_grad_enabled());
                panic!("intentional panic inside no_grad");
            });
        });
        assert!(result.is_err());
        assert!(is_grad_enabled());
    }

    #[test]
    fn test_enable_grad_panic_safety() {
        no_grad(|| {
            assert!(!is_grad_enabled());
            let result = std::panic::catch_unwind(|| {
                enable_grad(|| {
                    assert!(is_grad_enabled());
                    panic!("intentional panic inside enable_grad");
                });
            });
            assert!(result.is_err());
            assert!(!is_grad_enabled());
        });
        assert!(is_grad_enabled());
    }
}