use std::cell::Cell;
thread_local! {
    /// Whether autocast is currently active on this thread.
    ///
    /// Set by `autocast` / `with_autocast_state` and restored when they return.
    static AUTOCAST_ENABLED: Cell<bool> = const { Cell::new(false) };
    /// The dtype reported by `autocast_dtype`; starts out as `AutocastDtype::F16`.
    static AUTOCAST_DTYPE: Cell<AutocastDtype> = const { Cell::new(AutocastDtype::F16) };
    /// Per-thread debug switch, written by `set_autocast_debug` and read by `is_autocast_debug`.
    static AUTOCAST_DEBUG: Cell<bool> = const { Cell::new(false) };
}
/// Turns the autocast debug flag on or off for the calling thread only.
///
/// The value is stored thread-locally and can be read back with [`is_autocast_debug`].
pub fn set_autocast_debug(enabled: bool) {
    AUTOCAST_DEBUG.with(|flag| {
        flag.set(enabled);
    });
}
pub fn is_autocast_debug() -> bool {
AUTOCAST_DEBUG.with(|d| d.get())
}
/// The dtype selected for autocast regions.
///
/// A thread starts with `F16` selected; [`autocast`] chooses the variant for the
/// duration of its closure.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AutocastDtype {
    /// The `F16` variant (the thread-local default).
    F16,
    /// The `BF16` variant.
    BF16,
}
pub fn is_autocast_enabled() -> bool {
AUTOCAST_ENABLED.with(|e| e.get())
}
pub fn autocast_dtype() -> AutocastDtype {
AUTOCAST_DTYPE.with(|d| d.get())
}
/// A copy of a thread's autocast configuration: the enabled flag plus the selected dtype.
///
/// Produced by [`current_autocast_snapshot`] and consumed by [`with_autocast_state`].
/// The debug flag is not part of the snapshot.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct AutocastSnapshot {
    /// Whether autocast is active.
    pub enabled: bool,
    /// The dtype selected for autocast.
    pub dtype: AutocastDtype,
}
/// Captures the calling thread's current autocast configuration.
///
/// Passing the result to [`with_autocast_state`] runs code under the same
/// enabled flag and dtype.
pub fn current_autocast_snapshot() -> AutocastSnapshot {
    let enabled = is_autocast_enabled();
    let dtype = autocast_dtype();
    AutocastSnapshot { enabled, dtype }
}
pub fn with_autocast_state<F, R>(snapshot: AutocastSnapshot, f: F) -> R
where
F: FnOnce() -> R,
{
struct StateGuard {
prev_enabled: bool,
prev_dtype: AutocastDtype,
}
impl Drop for StateGuard {
fn drop(&mut self) {
AUTOCAST_ENABLED.with(|e| e.set(self.prev_enabled));
AUTOCAST_DTYPE.with(|d| d.set(self.prev_dtype));
}
}
let _guard = StateGuard {
prev_enabled: is_autocast_enabled(),
prev_dtype: autocast_dtype(),
};
AUTOCAST_ENABLED.with(|e| e.set(snapshot.enabled));
AUTOCAST_DTYPE.with(|d| d.set(snapshot.dtype));
f()
}
pub fn autocast<F, R>(dtype: AutocastDtype, f: F) -> R
where
F: FnOnce() -> R,
{
struct AutocastGuard {
prev_enabled: bool,
prev_dtype: AutocastDtype,
}
impl Drop for AutocastGuard {
fn drop(&mut self) {
AUTOCAST_ENABLED.with(|e| e.set(self.prev_enabled));
AUTOCAST_DTYPE.with(|d| d.set(self.prev_dtype));
}
}
let _guard = AutocastGuard {
prev_enabled: is_autocast_enabled(),
prev_dtype: autocast_dtype(),
};
AUTOCAST_ENABLED.with(|e| e.set(true));
AUTOCAST_DTYPE.with(|d| d.set(dtype));
f()
}
#[cfg(test)]
mod tests {
    use super::*;

    // These tests assume each one starts from the default thread-local state
    // (libtest runs each test on its own thread). Every test also restores
    // whatever it changes.

    /// The state a fresh thread starts with.
    const DEFAULT_STATE: AutocastSnapshot = AutocastSnapshot {
        enabled: false,
        dtype: AutocastDtype::F16,
    };

    #[test]
    fn test_autocast_default_disabled() {
        assert!(!is_autocast_enabled());
    }

    #[test]
    fn test_autocast_enables() {
        assert!(!is_autocast_enabled());
        autocast(AutocastDtype::F16, || {
            assert!(is_autocast_enabled());
        });
        assert!(!is_autocast_enabled());
    }

    #[test]
    fn test_autocast_nested() {
        assert!(!is_autocast_enabled());
        autocast(AutocastDtype::F16, || {
            assert!(is_autocast_enabled());
            assert_eq!(autocast_dtype(), AutocastDtype::F16);
            autocast(AutocastDtype::BF16, || {
                assert!(is_autocast_enabled());
                assert_eq!(autocast_dtype(), AutocastDtype::BF16);
            });
            assert!(is_autocast_enabled());
            assert_eq!(autocast_dtype(), AutocastDtype::F16);
        });
        assert!(!is_autocast_enabled());
    }

    #[test]
    fn test_autocast_dtype_selection() {
        autocast(AutocastDtype::BF16, || {
            assert_eq!(autocast_dtype(), AutocastDtype::BF16);
        });
        autocast(AutocastDtype::F16, || {
            assert_eq!(autocast_dtype(), AutocastDtype::F16);
        });
    }

    #[test]
    fn test_default_dtype_is_f16() {
        assert_eq!(autocast_dtype(), AutocastDtype::F16);
    }

    #[test]
    fn test_autocast_returns_closure_value() {
        assert_eq!(autocast(AutocastDtype::BF16, || 7), 7);
    }

    #[test]
    fn test_autocast_panic_safety() {
        let result = std::panic::catch_unwind(|| {
            autocast(AutocastDtype::BF16, || {
                assert!(is_autocast_enabled());
                panic!("intentional panic inside autocast");
            });
        });
        assert!(result.is_err());
        assert!(!is_autocast_enabled());
    }

    #[test]
    fn test_current_snapshot_tracks_autocast() {
        assert_eq!(current_autocast_snapshot(), DEFAULT_STATE);
        autocast(AutocastDtype::BF16, || {
            let expected = AutocastSnapshot {
                enabled: true,
                dtype: AutocastDtype::BF16,
            };
            assert_eq!(current_autocast_snapshot(), expected);
        });
        assert_eq!(current_autocast_snapshot(), DEFAULT_STATE);
    }

    #[test]
    fn test_with_autocast_state_applies_and_restores() {
        let snapshot = AutocastSnapshot {
            enabled: true,
            dtype: AutocastDtype::BF16,
        };
        let out = with_autocast_state(snapshot, || {
            assert_eq!(current_autocast_snapshot(), snapshot);
            42
        });
        assert_eq!(out, 42);
        assert_eq!(current_autocast_snapshot(), DEFAULT_STATE);
    }

    #[test]
    fn test_with_autocast_state_can_disable_inside_autocast() {
        autocast(AutocastDtype::BF16, || {
            with_autocast_state(DEFAULT_STATE, || {
                assert!(!is_autocast_enabled());
                assert_eq!(autocast_dtype(), AutocastDtype::F16);
            });
            assert!(is_autocast_enabled());
            assert_eq!(autocast_dtype(), AutocastDtype::BF16);
        });
    }

    #[test]
    fn test_with_autocast_state_panic_safety() {
        let snapshot = AutocastSnapshot {
            enabled: true,
            dtype: AutocastDtype::BF16,
        };
        let result = std::panic::catch_unwind(move || {
            with_autocast_state(snapshot, || {
                assert!(is_autocast_enabled());
                panic!("intentional panic inside with_autocast_state");
            });
        });
        assert!(result.is_err());
        assert_eq!(current_autocast_snapshot(), DEFAULT_STATE);
    }

    #[test]
    fn test_autocast_debug_flag() {
        assert!(!is_autocast_debug());
        set_autocast_debug(true);
        assert!(is_autocast_debug());
        set_autocast_debug(false);
        assert!(!is_autocast_debug());
    }
}