use std::cell::Cell;
// Per-thread gradient-tracking flag. It is read by `is_grad_enabled` and changed by
// `no_grad`, `enable_grad`, and `set_grad_enabled`.
thread_local! {
/// Whether gradient tracking is enabled on the current thread.
///
/// Every thread starts with `true`. Because the flag is thread-local, changing it
/// on one thread does not affect any other thread. The `const { ... }` initializer
/// lets the standard library use a thread-local implementation that can avoid
/// lazy-initialization checks on access.
static GRAD_ENABLED: Cell<bool> = const { Cell::new(true) };
}
pub fn is_grad_enabled() -> bool {
GRAD_ENABLED.with(|g| g.get())
}
/// Runs `f` with gradient tracking disabled on the current thread.
///
/// The previous setting is restored when `f` returns. It is also restored if `f`
/// panics and the panic is later caught (e.g. with [`std::panic::catch_unwind`]),
/// so a failing closure cannot leave the thread stuck in no-grad mode.
///
/// Calls nest: an inner `no_grad` or [`enable_grad`] restores whatever the outer
/// scope had in effect when it exits.
///
/// Returns whatever `f` returns.
pub fn no_grad<F, R>(f: F) -> R
where
    F: FnOnce() -> R,
{
    with_grad_mode(false, f)
}

/// Runs `f` with gradient tracking enabled on the current thread.
///
/// This is useful, for example, inside [`no_grad`] to re-enable tracking for a
/// sub-region. As with [`no_grad`], the previous setting is restored when `f`
/// returns or unwinds.
///
/// Returns whatever `f` returns.
pub fn enable_grad<F, R>(f: F) -> R
where
    F: FnOnce() -> R,
{
    with_grad_mode(true, f)
}

/// Runs `f` with the thread-local gradient flag forced to `enabled`.
///
/// The prior value is restored afterwards through a drop guard, so the restore
/// also happens during panic unwinding.
fn with_grad_mode<F, R>(enabled: bool, f: F) -> R
where
    F: FnOnce() -> R,
{
    // `replace` installs the new mode and returns the old one in a single step.
    // The binding must be named (not `_`) so the guard lives until `f` finishes.
    let _restore = GradModeGuard {
        prev: GRAD_ENABLED.with(|cell| cell.replace(enabled)),
    };
    f()
    // `_restore` is dropped after `f` returns, or while unwinding out of `f`,
    // which writes the saved mode back in either case.
}

/// Drop guard that writes a saved gradient mode back to the thread-local flag.
struct GradModeGuard {
    prev: bool,
}

impl Drop for GradModeGuard {
    fn drop(&mut self) {
        GRAD_ENABLED.with(|cell| cell.set(self.prev));
    }
}
/// Sets gradient tracking for the current thread to `enabled`.
///
/// Unlike [`no_grad`] and [`enable_grad`], this change is not scoped: it stays in
/// effect until changed again. When called inside one of those scoped closures,
/// the enclosing call still restores its own saved value once the closure ends.
pub fn set_grad_enabled(enabled: bool) {
    GRAD_ENABLED.with(|flag| {
        flag.set(enabled);
    });
}
#[cfg(test)]
mod tests {
    //! Unit tests for the thread-local gradient-mode switches.
    //!
    //! Values are observed inside the scoped closures and asserted afterwards.
    //! NOTE(review): several tests assume they start from the default (`true`)
    //! flag, which holds when the harness runs each test on a fresh thread.
    use super::*;

    #[test]
    fn test_grad_enabled_default() {
        let enabled = is_grad_enabled();
        assert!(enabled);
    }

    #[test]
    fn test_no_grad_disables() {
        assert!(is_grad_enabled());
        let during = no_grad(is_grad_enabled);
        assert!(!during);
        assert!(is_grad_enabled());
    }

    #[test]
    fn test_no_grad_nested() {
        assert!(is_grad_enabled());
        let observed = no_grad(|| {
            let outer_before = is_grad_enabled();
            let inner = no_grad(is_grad_enabled);
            let outer_after = is_grad_enabled();
            [outer_before, inner, outer_after]
        });
        assert_eq!(observed, [false; 3]);
        assert!(is_grad_enabled());
    }

    #[test]
    fn test_enable_grad_inside_no_grad() {
        let observed = no_grad(|| {
            let outer_before = is_grad_enabled();
            let inner = enable_grad(is_grad_enabled);
            let outer_after = is_grad_enabled();
            [outer_before, inner, outer_after]
        });
        assert_eq!(observed, [false, true, false]);
        assert!(is_grad_enabled());
    }

    #[test]
    fn test_enable_grad_returns_value() {
        assert_eq!(no_grad(|| enable_grad(|| 42)), 42);
    }

    #[test]
    fn test_enable_grad_when_already_enabled() {
        assert!(is_grad_enabled());
        let (inside, result) = enable_grad(|| (is_grad_enabled(), 99));
        assert!(inside);
        assert_eq!(result, 99);
        assert!(is_grad_enabled());
    }

    #[test]
    fn test_set_grad_enabled() {
        assert!(is_grad_enabled());
        for want in [false, true] {
            set_grad_enabled(want);
            assert_eq!(is_grad_enabled(), want);
        }
    }

    #[test]
    fn test_set_grad_enabled_inside_no_grad() {
        no_grad(|| {
            assert!(!is_grad_enabled());
            for want in [true, false] {
                set_grad_enabled(want);
                assert_eq!(is_grad_enabled(), want);
            }
        });
        assert!(is_grad_enabled());
    }
}