use std::cell::Cell;
thread_local! {
    /// Per-thread switch: `true` while autocast is active on the current thread.
    /// Every thread starts with it set to `false`.
    static AUTOCAST_ENABLED: Cell<bool> = const { Cell::new(false) };
}

thread_local! {
    /// Per-thread autocast dtype selection. Every thread starts with
    /// `AutocastDtype::F16`.
    static AUTOCAST_DTYPE: Cell<AutocastDtype> = const { Cell::new(AutocastDtype::F16) };
}
/// Dtype that can be selected for an autocast scope.
///
/// The [`Default`] value is [`AutocastDtype::F16`], the same value the
/// per-thread dtype selection starts with. `Hash` is derived alongside `Eq`
/// so the dtype can be used as a `HashMap`/`HashSet` key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum AutocastDtype {
    /// The `F16` dtype (presumably IEEE 754 half precision); the initial selection.
    #[default]
    F16,
    /// The `BF16` dtype (presumably bfloat16).
    BF16,
}
pub fn is_autocast_enabled() -> bool {
AUTOCAST_ENABLED.with(|e| e.get())
}
pub fn autocast_dtype() -> AutocastDtype {
AUTOCAST_DTYPE.with(|d| d.get())
}
/// Runs `f` with autocast enabled on the calling thread, using `dtype`.
///
/// The thread's previous enabled flag and dtype are saved before `f` runs and
/// put back once it finishes, so calls can be nested: leaving an inner scope
/// brings back the outer scope's settings.
///
/// The restore is performed by a drop guard, so it also happens when `f`
/// panics and the panic unwinds. A caught panic therefore cannot leave the
/// thread with autocast stuck on.
///
/// Returns whatever `f` returns.
pub fn autocast<F, R>(dtype: AutocastDtype, f: F) -> R
where
    F: FnOnce() -> R,
{
    /// Holds the state that was active before the scope and writes it back
    /// when dropped.
    struct Restore {
        enabled: bool,
        dtype: AutocastDtype,
    }

    impl Drop for Restore {
        fn drop(&mut self) {
            AUTOCAST_DTYPE.with(|d| d.set(self.dtype));
            AUTOCAST_ENABLED.with(|e| e.set(self.enabled));
        }
    }

    // `Cell::replace` installs the new value and hands back the old one.
    // Struct fields are evaluated in the order written: flag first, then dtype.
    let _restore = Restore {
        enabled: AUTOCAST_ENABLED.with(|e| e.replace(true)),
        dtype: AUTOCAST_DTYPE.with(|d| d.replace(dtype)),
    };

    // `_restore` is dropped after `f` returns, or while unwinding if it panics.
    f()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Dtypes walked by the selection test, in the order they are tried.
    const SELECTION_ORDER: [AutocastDtype; 2] = [AutocastDtype::BF16, AutocastDtype::F16];

    /// Asserts that autocast is off on the current thread.
    #[track_caller]
    fn assert_disabled() {
        assert!(!is_autocast_enabled());
    }

    /// A thread that has not entered an autocast scope reports it as disabled.
    #[test]
    fn test_autocast_default_disabled() {
        assert_disabled();
    }

    /// Autocast is on inside the scope and off again once the scope ends.
    #[test]
    fn test_autocast_enables() {
        assert_disabled();
        autocast(AutocastDtype::F16, || assert!(is_autocast_enabled()));
        assert_disabled();
    }

    /// Leaving an inner scope brings back the outer scope's dtype.
    #[test]
    fn test_autocast_nested() {
        let outer = AutocastDtype::F16;
        let inner = AutocastDtype::BF16;

        assert_disabled();
        autocast(outer, || {
            assert!(is_autocast_enabled());
            assert_eq!(autocast_dtype(), outer);

            autocast(inner, || {
                assert!(is_autocast_enabled());
                assert_eq!(autocast_dtype(), inner);
            });

            assert!(is_autocast_enabled());
            assert_eq!(autocast_dtype(), outer);
        });
        assert_disabled();
    }

    /// The dtype passed to `autocast` is the one observed inside the scope.
    #[test]
    fn test_autocast_dtype_selection() {
        for dtype in SELECTION_ORDER {
            autocast(dtype, || assert_eq!(autocast_dtype(), dtype));
        }
    }

    /// Outside any scope the selected dtype is `F16`.
    #[test]
    fn test_default_dtype_is_f16() {
        let initial = autocast_dtype();
        assert_eq!(initial, AutocastDtype::F16);
    }
}