use super::autocast::{is_autocast_debug, is_autocast_enabled};
/// Precision bucket that [`autocast_category`] assigns to an operation name.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AutocastCategory {
/// Ops such as `"mm"`, `"linear"`, and `"conv2d"`; `should_cast_to_reduced`
/// reports `true` for these while autocast is enabled.
ReducedPrecision,
/// Ops such as `"softmax"`, `"layer_norm"`, and `"cross_entropy"`;
/// `should_keep_full_precision` reports `true` for these while autocast is enabled.
FullPrecision,
/// Fallback for every op name that is in neither of the two lists above.
Passthrough,
}
/// Classifies an operation name into its [`AutocastCategory`].
///
/// The lookup is an exact, case-sensitive string comparison against two fixed
/// tables. Any name found in neither table — including unknown or custom
/// ops — is reported as [`AutocastCategory::Passthrough`].
///
/// This function does not consult the autocast enabled state; it only maps
/// names to categories.
pub fn autocast_category(op_name: &str) -> AutocastCategory {
    // Ops classified as reduced precision.
    const REDUCED_OPS: &[&str] = &[
        "mm",
        "matmul",
        "bmm",
        "linear",
        "conv1d",
        "conv2d",
        "conv_transpose2d",
        "addmm",
        "einsum",
    ];
    // Ops classified as full precision.
    const FULL_OPS: &[&str] = &[
        "sum",
        "mean",
        "prod",
        "softmax",
        "log_softmax",
        "layer_norm",
        "batch_norm",
        "group_norm",
        "rms_norm",
        "cross_entropy",
        "mse_loss",
        "bce_with_logits",
    ];

    if REDUCED_OPS.contains(&op_name) {
        AutocastCategory::ReducedPrecision
    } else if FULL_OPS.contains(&op_name) {
        AutocastCategory::FullPrecision
    } else {
        AutocastCategory::Passthrough
    }
}
/// Reports whether `op_name` is classified as reduced precision right now.
///
/// Returns `true` only when autocast is currently enabled *and*
/// [`autocast_category`] classifies `op_name` as
/// [`AutocastCategory::ReducedPrecision`]. With autocast disabled the result
/// is always `false`, regardless of the op.
pub fn should_cast_to_reduced(op_name: &str) -> bool {
    // Check the enabled flag first so the category lookup is skipped when off.
    if !is_autocast_enabled() {
        return false;
    }
    matches!(
        autocast_category(op_name),
        AutocastCategory::ReducedPrecision
    )
}
/// Reports whether `op_name` is classified as full precision right now.
///
/// Returns `true` only when autocast is currently enabled *and*
/// [`autocast_category`] classifies `op_name` as
/// [`AutocastCategory::FullPrecision`]. With autocast disabled the result is
/// always `false`, regardless of the op.
pub fn should_keep_full_precision(op_name: &str) -> bool {
    // Check the enabled flag first so the category lookup is skipped when off.
    if !is_autocast_enabled() {
        return false;
    }
    matches!(
        autocast_category(op_name),
        AutocastCategory::FullPrecision
    )
}
/// Looks up the autocast category for `op_name`, optionally logging the lookup.
///
/// Returns `None` when autocast is not enabled; in that case nothing is
/// classified or recorded. Otherwise returns `Some(category)` as computed by
/// [`autocast_category`].
///
/// When the autocast debug flag is also set, an [`AutocastEvent`] holding the
/// op name and its category is appended to the calling thread's event log,
/// which can be retrieved with [`drain_autocast_events`].
pub fn autocast_guard(op_name: &str) -> Option<AutocastCategory> {
    is_autocast_enabled().then(|| {
        let category = autocast_category(op_name);
        if is_autocast_debug() {
            let event = AutocastEvent {
                op: String::from(op_name),
                category,
            };
            AUTOCAST_EVENTS.with(|log| log.borrow_mut().push(event));
        }
        category
    })
}
/// Alias for [`autocast_guard`]: forwards `op_name` unchanged and returns its result.
///
/// As with `autocast_guard`, the result is `None` when autocast is disabled and
/// `Some(category)` otherwise, and a debug event is recorded when the debug flag is on.
pub fn autocast_log(op_name: &str) -> Option<AutocastCategory> {
autocast_guard(op_name)
}
/// One record appended by `autocast_guard` when autocast is enabled and the
/// debug flag is on.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AutocastEvent {
/// Operation name exactly as it was passed to `autocast_guard`.
pub op: String,
/// Category that `autocast_category` returned for `op`.
pub category: AutocastCategory,
}
// Per-thread log of debug events. `autocast_guard` appends to it while
// `is_autocast_debug()` returns true, and `drain_autocast_events` empties it.
// Each thread has its own log, so events recorded on one thread are not
// visible from another.
thread_local! {
static AUTOCAST_EVENTS: std::cell::RefCell<Vec<AutocastEvent>> =
const { std::cell::RefCell::new(Vec::new()) };
}
/// Removes and returns every debug event recorded on the current thread.
///
/// Events are returned in the order they were recorded. After the call the
/// thread's log is empty, so an immediate second call yields an empty `Vec`.
/// Events recorded on other threads are not affected.
///
/// # Panics
///
/// Panics if the thread-local log is already borrowed, which would only
/// happen if this were called re-entrantly while the log is being modified.
pub fn drain_autocast_events() -> Vec<AutocastEvent> {
    // Swap the whole buffer out in O(1) rather than draining it element by
    // element into a freshly allocated `Vec`.
    AUTOCAST_EVENTS.with(|events| std::mem::take(&mut *events.borrow_mut()))
}
// Unit tests for the autocast policy helpers in this module.
//
// They cover the name -> category table, the enabled/disabled gating of the
// `should_*` helpers, the `autocast_guard` / `autocast_log` return values, and
// the debug event log.
#[cfg(test)]
mod tests {
use super::*;
use crate::autograd::autocast::{AutocastDtype, autocast};
// --- Category table: ops expected to be ReducedPrecision ---
#[test]
fn test_mm_is_reduced_precision() {
assert_eq!(autocast_category("mm"), AutocastCategory::ReducedPrecision);
}
#[test]
fn test_matmul_is_reduced_precision() {
assert_eq!(
autocast_category("matmul"),
AutocastCategory::ReducedPrecision
);
}
#[test]
fn test_bmm_is_reduced_precision() {
assert_eq!(autocast_category("bmm"), AutocastCategory::ReducedPrecision);
}
#[test]
fn test_linear_is_reduced_precision() {
assert_eq!(
autocast_category("linear"),
AutocastCategory::ReducedPrecision
);
}
#[test]
fn test_conv2d_is_reduced_precision() {
assert_eq!(
autocast_category("conv2d"),
AutocastCategory::ReducedPrecision
);
}
// --- Category table: ops expected to be FullPrecision ---
#[test]
fn test_softmax_is_full_precision() {
assert_eq!(
autocast_category("softmax"),
AutocastCategory::FullPrecision
);
}
#[test]
fn test_log_softmax_is_full_precision() {
assert_eq!(
autocast_category("log_softmax"),
AutocastCategory::FullPrecision
);
}
#[test]
fn test_layer_norm_is_full_precision() {
assert_eq!(
autocast_category("layer_norm"),
AutocastCategory::FullPrecision
);
}
#[test]
fn test_batch_norm_is_full_precision() {
assert_eq!(
autocast_category("batch_norm"),
AutocastCategory::FullPrecision
);
}
#[test]
fn test_cross_entropy_is_full_precision() {
assert_eq!(
autocast_category("cross_entropy"),
AutocastCategory::FullPrecision
);
}
#[test]
fn test_mse_loss_is_full_precision() {
assert_eq!(
autocast_category("mse_loss"),
AutocastCategory::FullPrecision
);
}
#[test]
fn test_sum_is_full_precision() {
assert_eq!(autocast_category("sum"), AutocastCategory::FullPrecision);
}
#[test]
fn test_mean_is_full_precision() {
assert_eq!(autocast_category("mean"), AutocastCategory::FullPrecision);
}
// --- Category table: ops in neither list fall back to Passthrough ---
#[test]
fn test_add_is_passthrough() {
assert_eq!(autocast_category("add"), AutocastCategory::Passthrough);
}
#[test]
fn test_mul_is_passthrough() {
assert_eq!(autocast_category("mul"), AutocastCategory::Passthrough);
}
#[test]
fn test_relu_is_passthrough() {
assert_eq!(autocast_category("relu"), AutocastCategory::Passthrough);
}
#[test]
fn test_unknown_op_is_passthrough() {
assert_eq!(
autocast_category("some_custom_op"),
AutocastCategory::Passthrough
);
}
// --- should_cast_to_reduced: gated on the autocast enabled state ---
// Called outside any `autocast(...)` scope; the test expects `false` there.
#[test]
fn test_should_cast_to_reduced_false_when_disabled() {
assert!(!should_cast_to_reduced("mm"));
assert!(!should_cast_to_reduced("matmul"));
assert!(!should_cast_to_reduced("linear"));
}
#[test]
fn test_should_cast_to_reduced_true_for_mm_when_enabled() {
autocast(AutocastDtype::F16, || {
assert!(should_cast_to_reduced("mm"));
assert!(should_cast_to_reduced("matmul"));
assert!(should_cast_to_reduced("linear"));
assert!(should_cast_to_reduced("conv2d"));
});
}
#[test]
fn test_should_cast_to_reduced_false_for_passthrough_when_enabled() {
autocast(AutocastDtype::F16, || {
assert!(!should_cast_to_reduced("add"));
assert!(!should_cast_to_reduced("relu"));
});
}
#[test]
fn test_should_cast_to_reduced_false_for_full_precision_when_enabled() {
autocast(AutocastDtype::F16, || {
assert!(!should_cast_to_reduced("softmax"));
assert!(!should_cast_to_reduced("layer_norm"));
});
}
// --- should_keep_full_precision: gated on the autocast enabled state ---
#[test]
fn test_should_keep_full_precision_false_when_disabled() {
assert!(!should_keep_full_precision("softmax"));
assert!(!should_keep_full_precision("layer_norm"));
}
#[test]
fn test_should_keep_full_precision_true_when_enabled() {
autocast(AutocastDtype::BF16, || {
assert!(should_keep_full_precision("softmax"));
assert!(should_keep_full_precision("layer_norm"));
assert!(should_keep_full_precision("cross_entropy"));
});
}
#[test]
fn test_should_keep_full_precision_false_for_reduced_when_enabled() {
autocast(AutocastDtype::F16, || {
assert!(!should_keep_full_precision("mm"));
assert!(!should_keep_full_precision("matmul"));
});
}
// --- autocast_log: None outside an autocast scope, Some(category) inside ---
#[test]
fn test_autocast_log_none_when_disabled() {
assert!(autocast_log("mm").is_none());
assert!(autocast_log("softmax").is_none());
assert!(autocast_log("add").is_none());
}
#[test]
fn test_autocast_log_returns_category_when_enabled() {
autocast(AutocastDtype::F16, || {
assert_eq!(autocast_log("mm"), Some(AutocastCategory::ReducedPrecision));
assert_eq!(
autocast_log("softmax"),
Some(AutocastCategory::FullPrecision)
);
assert_eq!(autocast_log("add"), Some(AutocastCategory::Passthrough));
});
}
// Checks the policy before, inside, and after an `autocast` scope to confirm
// it is only active while the closure runs.
#[test]
fn test_policy_active_inside_context_inactive_outside() {
assert!(!should_cast_to_reduced("mm"));
assert!(autocast_log("mm").is_none());
autocast(AutocastDtype::F16, || {
assert!(should_cast_to_reduced("mm"));
assert!(autocast_log("mm").is_some());
});
assert!(!should_cast_to_reduced("mm"));
assert!(autocast_log("mm").is_none());
}
// Nested scopes: the policy must stay active in the inner scope, after the
// inner scope returns, and become inactive once the outer scope returns.
#[test]
fn test_nested_autocast_policy_still_works() {
autocast(AutocastDtype::F16, || {
assert!(should_cast_to_reduced("mm"));
autocast(AutocastDtype::BF16, || {
assert!(should_cast_to_reduced("mm"));
assert!(should_keep_full_precision("softmax"));
});
assert!(should_cast_to_reduced("mm"));
});
assert!(!should_cast_to_reduced("mm"));
}
// --- autocast_guard: return value ---
#[test]
fn test_autocast_guard_none_when_disabled() {
assert!(autocast_guard("mm").is_none());
assert!(autocast_guard("softmax").is_none());
}
#[test]
fn test_autocast_guard_returns_category() {
autocast(AutocastDtype::F16, || {
assert_eq!(
autocast_guard("mm"),
Some(AutocastCategory::ReducedPrecision)
);
assert_eq!(
autocast_guard("softmax"),
Some(AutocastCategory::FullPrecision)
);
assert_eq!(autocast_guard("add"), Some(AutocastCategory::Passthrough));
});
}
// --- autocast_guard: debug event log ---
// NOTE(review): if the flag behind `set_autocast_debug` is process-global
// rather than per-thread, the two tests below could interfere with each other
// under the parallel test runner — confirm against its definition.
#[test]
fn test_autocast_guard_debug_events() {
use crate::autograd::autocast::set_autocast_debug;
// Start from an empty log on this thread.
drain_autocast_events();
set_autocast_debug(true);
let events = autocast(AutocastDtype::F16, || {
autocast_guard("mm");
autocast_guard("softmax");
autocast_guard("relu");
drain_autocast_events()
});
// Debug is switched off before asserting, so a failed assertion below does
// not leave the flag set.
set_autocast_debug(false);
assert_eq!(events.len(), 3);
assert_eq!(events[0].op, "mm");
assert_eq!(events[0].category, AutocastCategory::ReducedPrecision);
assert_eq!(events[1].op, "softmax");
assert_eq!(events[1].category, AutocastCategory::FullPrecision);
assert_eq!(events[2].op, "relu");
assert_eq!(events[2].category, AutocastCategory::Passthrough);
}
#[test]
fn test_autocast_guard_no_events_without_debug() {
use crate::autograd::autocast::set_autocast_debug;
// Start from an empty log with debug explicitly off.
drain_autocast_events();
set_autocast_debug(false);
autocast(AutocastDtype::F16, || {
autocast_guard("mm");
autocast_guard("linear");
});
let events = drain_autocast_events();
assert!(
events.is_empty(),
"no events should be recorded when debug is off"
);
}
// --- Category table: further ReducedPrecision entries ---
#[test]
fn test_addmm_is_reduced_precision() {
assert_eq!(
autocast_category("addmm"),
AutocastCategory::ReducedPrecision
);
}
#[test]
fn test_einsum_is_reduced_precision() {
assert_eq!(
autocast_category("einsum"),
AutocastCategory::ReducedPrecision
);
}
}