use std::cell::RefCell;
use std::sync::Mutex;
/// Autograd settings: whether gradients are tracked and whether anomaly
/// detection is active.
///
/// The module keeps one instance per thread in the thread-local
/// `GRADIENT_CONTEXT`, which the `is_*` / `set_*` functions read and write.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct GradientContext {
    /// Gradient-tracking flag, reported by `is_grad_enabled`.
    pub grad_enabled: bool,
    /// Anomaly-detection flag, reported by `is_anomaly_detection_enabled`.
    pub anomaly_detection: bool,
}

impl Default for GradientContext {
    /// Gradient tracking on, anomaly detection off.
    fn default() -> Self {
        Self {
            grad_enabled: true,
            anomaly_detection: false,
        }
    }
}
thread_local! {
static GRADIENT_CONTEXT: RefCell<GradientContext> = RefCell::new(GradientContext::default());
}
/// Process-wide `GradientContext` behind a `Mutex`, initialised to the same
/// values as `GradientContext::default()` (spelled out because `Default` is not
/// callable in a `static` initializer).
///
/// NOTE(review): the accessor functions in this module use the thread-local
/// `GRADIENT_CONTEXT`, not this static; confirm whether it is still needed.
static GLOBAL_GRADIENT_CONTEXT: Mutex<GradientContext> = Mutex::new(GradientContext {
    anomaly_detection: false,
    grad_enabled: true,
});
/// Reports whether gradient tracking is currently enabled on the calling thread.
pub fn is_grad_enabled() -> bool {
    GRADIENT_CONTEXT.with(|cell| {
        let ctx = cell.borrow();
        ctx.grad_enabled
    })
}
/// Reports whether anomaly detection is currently enabled on the calling thread.
pub fn is_anomaly_detection_enabled() -> bool {
    GRADIENT_CONTEXT.with(|cell| {
        let ctx = cell.borrow();
        ctx.anomaly_detection
    })
}
/// Turns gradient tracking on or off for the calling thread only; other
/// threads keep their own setting.
pub fn set_grad_enabled(enabled: bool) {
    GRADIENT_CONTEXT.with(|cell| cell.borrow_mut().grad_enabled = enabled);
}
/// Turns anomaly detection on or off for the calling thread only; other
/// threads keep their own setting.
pub fn set_anomaly_detection(enabled: bool) {
    GRADIENT_CONTEXT.with(|cell| {
        let mut ctx = cell.borrow_mut();
        ctx.anomaly_detection = enabled;
    });
}
/// Scope guard that disables gradient tracking on the calling thread while it
/// is alive.
///
/// Creating it records the current `is_grad_enabled()` value and sets it to
/// `false`. Dropping it writes the recorded value back. When guards are dropped
/// in reverse order of creation, as scoped locals are, each one restores the
/// state that was in effect before it.
///
/// The setting lives in a thread-local, so the restore happens on whichever
/// thread drops the guard. Drop it on the thread that created it.
#[must_use = "dropping the guard immediately restores the previous gradient state; bind it to a variable"]
#[derive(Debug)]
pub struct NoGradGuard {
    /// Gradient-tracking state observed when the guard was created.
    previous_state: bool,
}

impl NoGradGuard {
    /// Disables gradient tracking and returns a guard that restores the
    /// previous setting when dropped.
    pub fn new() -> Self {
        let previous_state = is_grad_enabled();
        set_grad_enabled(false);
        Self { previous_state }
    }
}

impl Default for NoGradGuard {
    /// Same as [`NoGradGuard::new`]: this *does* disable gradient tracking.
    fn default() -> Self {
        Self::new()
    }
}

impl Drop for NoGradGuard {
    fn drop(&mut self) {
        set_grad_enabled(self.previous_state);
    }
}
/// Scope guard that enables gradient tracking on the calling thread while it
/// is alive.
///
/// Creating it records the current `is_grad_enabled()` value and sets it to
/// `true`. Dropping it writes the recorded value back. This makes it usable
/// inside a region where a `NoGradGuard` is active.
///
/// The setting lives in a thread-local, so the restore happens on whichever
/// thread drops the guard. Drop it on the thread that created it.
#[must_use = "dropping the guard immediately restores the previous gradient state; bind it to a variable"]
#[derive(Debug)]
pub struct EnableGradGuard {
    /// Gradient-tracking state observed when the guard was created.
    previous_state: bool,
}

impl EnableGradGuard {
    /// Enables gradient tracking and returns a guard that restores the
    /// previous setting when dropped.
    pub fn new() -> Self {
        let previous_state = is_grad_enabled();
        set_grad_enabled(true);
        Self { previous_state }
    }
}

impl Default for EnableGradGuard {
    /// Same as [`EnableGradGuard::new`]: this *does* enable gradient tracking.
    fn default() -> Self {
        Self::new()
    }
}

impl Drop for EnableGradGuard {
    fn drop(&mut self) {
        set_grad_enabled(self.previous_state);
    }
}
/// Scope guard that enables anomaly detection on the calling thread while it
/// is alive.
///
/// Creating it records the current `is_anomaly_detection_enabled()` value and
/// sets it to `true`. Dropping it writes the recorded value back.
///
/// The setting lives in a thread-local, so the restore happens on whichever
/// thread drops the guard. Drop it on the thread that created it.
#[must_use = "dropping the guard immediately restores the previous anomaly-detection state; bind it to a variable"]
#[derive(Debug)]
pub struct AnomalyDetectionGuard {
    /// Anomaly-detection state observed when the guard was created.
    previous_state: bool,
}

impl AnomalyDetectionGuard {
    /// Enables anomaly detection and returns a guard that restores the
    /// previous setting when dropped.
    pub fn new() -> Self {
        let previous_state = is_anomaly_detection_enabled();
        set_anomaly_detection(true);
        Self { previous_state }
    }
}

impl Default for AnomalyDetectionGuard {
    /// Same as [`AnomalyDetectionGuard::new`]: this *does* enable anomaly detection.
    fn default() -> Self {
        Self::new()
    }
}

impl Drop for AnomalyDetectionGuard {
    fn drop(&mut self) {
        set_anomaly_detection(self.previous_state);
    }
}
/// Runs `f` with gradient tracking disabled on the calling thread and returns
/// its result.
///
/// The previous setting is restored after `f` returns. Under the default
/// `panic = "unwind"` strategy it is also restored if `f` panics, because the
/// restore is done by `NoGradGuard`'s `Drop`.
pub fn no_grad<F, R>(f: F) -> R
where
    F: FnOnce() -> R,
{
    let restore = NoGradGuard::new();
    let result = f();
    drop(restore);
    result
}
/// Runs `f` with gradient tracking enabled on the calling thread and returns
/// its result.
///
/// The previous setting is restored after `f` returns. Under the default
/// `panic = "unwind"` strategy it is also restored if `f` panics, because the
/// restore is done by `EnableGradGuard`'s `Drop`.
pub fn enable_grad<F, R>(f: F) -> R
where
    F: FnOnce() -> R,
{
    let restore = EnableGradGuard::new();
    let result = f();
    drop(restore);
    result
}
/// Runs `f` with anomaly detection enabled on the calling thread and returns
/// its result.
///
/// The previous setting is restored after `f` returns. Under the default
/// `panic = "unwind"` strategy it is also restored if `f` panics, because the
/// restore is done by `AnomalyDetectionGuard`'s `Drop`.
pub fn detect_anomaly<F, R>(f: F) -> R
where
    F: FnOnce() -> R,
{
    let restore = AnomalyDetectionGuard::new();
    let result = f();
    drop(restore);
    result
}
#[cfg(test)]
mod tests {
    use super::*;

    // These tests mutate the calling thread's context; each one puts back
    // anything it changes so it does not leak state into later code on the
    // same thread.

    #[test]
    fn test_no_grad_guard() {
        assert!(is_grad_enabled());
        {
            let _guard = NoGradGuard::new();
            assert!(!is_grad_enabled());
        }
        assert!(is_grad_enabled());
    }

    #[test]
    fn test_enable_grad_guard() {
        set_grad_enabled(false);
        assert!(!is_grad_enabled());
        {
            let _guard = EnableGradGuard::new();
            assert!(is_grad_enabled());
        }
        assert!(!is_grad_enabled());
        set_grad_enabled(true);
    }

    #[test]
    fn test_anomaly_detection_guard() {
        assert!(!is_anomaly_detection_enabled());
        {
            let _guard = AnomalyDetectionGuard::new();
            assert!(is_anomaly_detection_enabled());
        }
        assert!(!is_anomaly_detection_enabled());
    }

    #[test]
    fn test_nested_guards_restore_in_order() {
        assert!(is_grad_enabled());
        {
            let _outer = NoGradGuard::new();
            assert!(!is_grad_enabled());
            {
                let _inner = EnableGradGuard::new();
                assert!(is_grad_enabled());
            }
            // The inner guard must restore the outer guard's state, not the default.
            assert!(!is_grad_enabled());
        }
        assert!(is_grad_enabled());
    }

    #[test]
    fn test_no_grad_restores_state_after_panic() {
        assert!(is_grad_enabled());
        let outcome = std::panic::catch_unwind(|| {
            no_grad::<_, ()>(|| {
                assert!(!is_grad_enabled());
                panic!("intentional panic inside no_grad")
            })
        });
        assert!(outcome.is_err());
        // The guard's Drop ran during unwinding and re-enabled gradients.
        assert!(is_grad_enabled());
    }

    #[test]
    fn test_convenience_functions() {
        let result = no_grad(|| {
            assert!(!is_grad_enabled());
            42
        });
        assert_eq!(result, 42);
        assert!(is_grad_enabled());

        set_grad_enabled(false);
        let result = enable_grad(|| {
            assert!(is_grad_enabled());
            24
        });
        assert_eq!(result, 24);
        assert!(!is_grad_enabled());
        set_grad_enabled(true);

        let result = detect_anomaly(|| {
            assert!(is_anomaly_detection_enabled());
            "test"
        });
        assert_eq!(result, "test");
        assert!(!is_anomaly_detection_enabled());
    }
}