use crate::dtype::DType;
use crate::tensor::Tensor;
use std::cell::RefCell;
thread_local! {
static AUTOCAST_STATE: RefCell<AutocastState> = RefCell::new(AutocastState::default());
}
/// Reduced-precision mode selected by autocast.
///
/// Derives `Eq` and `Hash` in addition to `PartialEq` so the mode can be used as a key in
/// hashed collections and compared with full equality semantics.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum AutocastMode {
    /// No mode selected; this is the mode whenever autocast is requested with `enabled == false`.
    None,
    /// Selected when autocast is enabled with `DType::Float16`, with no dtype, or with any
    /// dtype other than `DType::BFloat16`.
    FP16,
    /// Selected when autocast is enabled with `DType::BFloat16`.
    BF16,
}
/// Autocast settings tracked per thread through `AUTOCAST_STATE`.
#[derive(Clone, Debug)]
struct AutocastState {
/// Mode currently selected; written by `AutocastContext::new` and restored by `exit`.
mode: AutocastMode,
/// Whether autocast is currently switched on; reported by `is_autocast_enabled`.
enabled: bool,
/// Initialised to `true` by `Default`; not read or modified by the code visible in this file.
_cache_enabled: bool,
/// Incremented by every `AutocastContext::new` and reset to the saved value on `exit`.
level: usize,
}
impl Default for AutocastState {
fn default() -> Self {
Self {
mode: AutocastMode::None,
enabled: false,
_cache_enabled: true,
level: 0,
}
}
}
/// Guard returned by [`AutocastContext::new`] / `autocast`.
///
/// It remembers the thread's autocast state from just before the context was created; the
/// `Drop` implementation calls `exit`, which writes that state back. Discarding the guard
/// immediately therefore undoes the change right away, hence the `#[must_use]`.
#[must_use = "the previous autocast state is restored as soon as this guard is dropped; bind it to a variable"]
#[derive(Debug)]
pub struct AutocastContext {
    /// Snapshot of the thread's autocast state taken before this context modified it.
    prev_state: AutocastState,
    /// Device type string passed to `new`; stored but not read by the code visible in this file.
    _device_type: String,
}
impl AutocastContext {
pub fn new(device_type: &str, enabled: bool, dtype: Option<DType>) -> Self {
let mode = if enabled {
match dtype {
Some(DType::Float16) => AutocastMode::FP16,
Some(DType::BFloat16) => AutocastMode::BF16,
_ => AutocastMode::FP16, }
} else {
AutocastMode::None
};
let prev_state = AUTOCAST_STATE.with(|state| {
let mut s = state.borrow_mut();
let prev = s.clone();
s.mode = mode;
s.enabled = enabled;
s.level += 1;
prev
});
Self {
prev_state,
_device_type: device_type.to_string(),
}
}
pub fn enter(&self) {
}
pub fn exit(&self) {
AUTOCAST_STATE.with(|state| {
let mut s = state.borrow_mut();
s.mode = self.prev_state.mode;
s.enabled = self.prev_state.enabled;
s.level = self.prev_state.level;
});
}
}
impl Drop for AutocastContext {
/// Restores the autocast state saved at construction by delegating to `exit`.
fn drop(&mut self) {
self.exit();
}
}
/// Convenience wrapper that forwards its arguments unchanged to `AutocastContext::new`.
///
/// The returned context restores the previous autocast state when it is dropped, so keep it
/// bound to a variable for as long as the new settings should apply.
pub fn autocast(device_type: &str, enabled: bool, dtype: Option<DType>) -> AutocastContext {
AutocastContext::new(device_type, enabled, dtype)
}
/// Reports whether autocast is currently enabled on the calling thread.
pub fn is_autocast_enabled() -> bool {
    AUTOCAST_STATE.with(|cell| {
        let current = cell.borrow();
        current.enabled
    })
}
/// Returns the autocast mode currently selected on the calling thread.
pub fn get_autocast_mode() -> AutocastMode {
    AUTOCAST_STATE.with(|cell| {
        let current = cell.borrow();
        current.mode
    })
}
/// Passes `tensor` through the cast helper that matches the thread's current autocast mode.
///
/// Returns a plain clone when autocast is disabled or the mode is [`AutocastMode::None`];
/// otherwise returns the result of `cast_to_fp16` or `cast_to_bf16`. The result type stays
/// `Tensor<f32>`; what the helpers do to the values is defined in `crate::amp::dtype_utils`.
pub fn maybe_autocast_f32(tensor: &Tensor<f32>) -> Tensor<f32> {
    use crate::amp::dtype_utils::{cast_to_bf16, cast_to_fp16};

    if is_autocast_enabled() {
        match get_autocast_mode() {
            AutocastMode::BF16 => cast_to_bf16(tensor),
            AutocastMode::FP16 => cast_to_fp16(tensor),
            AutocastMode::None => tensor.clone(),
        }
    } else {
        tensor.clone()
    }
}
/// Expands to `$crate::amp::autocast::AutocastOp::apply($op_name, || $body)`, i.e. the body
/// expression is wrapped in a closure and handed to `AutocastOp::apply` with the op name.
///
/// NOTE(review): `AutocastOp` is not defined in the visible part of this file; confirm it
/// exists at `crate::amp::autocast::AutocastOp` before relying on this macro.
#[macro_export]
macro_rules! autocast_op {
($op_name:expr, $body:expr) => {
$crate::amp::autocast::AutocastOp::apply($op_name, || $body)
};
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A single context enables FP16 autocast and switches it off again once dropped.
    #[test]
    fn test_autocast_context() {
        assert!(!is_autocast_enabled());

        let guard = autocast("cuda", true, Some(DType::Float16));
        assert!(is_autocast_enabled());
        assert_eq!(get_autocast_mode(), AutocastMode::FP16);
        drop(guard);

        assert!(!is_autocast_enabled());
    }

    /// Nested contexts restore the enclosing mode as each inner context is dropped.
    #[test]
    fn test_nested_autocast() {
        assert!(!is_autocast_enabled());

        let outer = autocast("cuda", true, Some(DType::Float16));
        assert_eq!(get_autocast_mode(), AutocastMode::FP16);

        let inner = autocast("cuda", true, Some(DType::BFloat16));
        assert_eq!(get_autocast_mode(), AutocastMode::BF16);
        drop(inner);

        assert_eq!(get_autocast_mode(), AutocastMode::FP16);
        drop(outer);

        assert!(!is_autocast_enabled());
    }
}