use metal::{CommandBuffer, Device};
use std::collections::HashMap;
use crate::metal::{mps::MPSDataType, MetalBuffer, Result};
/// Dynamic loss-scaling state for mixed-precision training on a Metal device.
///
/// The loss scale grows by `loss_scale_factor` after `scale_window` consecutive
/// clean steps and shrinks by the same factor whenever an overflow is detected,
/// always staying within `[min_loss_scale, max_loss_scale]`.
pub struct MPSMixedPrecision {
    /// Metal device this helper was created for.
    device: Device,
    /// When `false`, scaling/unscaling degrade to plain buffer copies.
    enabled: bool,

    /// Current loss scale.
    loss_scaling: f32,
    /// Scale restored by `reset`.
    initial_loss_scale: f32,
    /// Lower clamp for the loss scale.
    min_loss_scale: f32,
    /// Upper clamp for the loss scale.
    max_loss_scale: f32,
    /// Multiplier used on growth and divisor used on backoff.
    loss_scale_factor: f32,
    /// Number of consecutive clean steps required before the scale grows.
    scale_window: usize,

    /// Clean steps seen since the last growth or overflow.
    consecutive_unskipped: usize,
    /// Number of growth events since the last backoff or reset.
    scale_growth_tracker: usize,
    /// Whether the most recent unscale step detected a non-finite gradient.
    found_inf: bool,
}
impl MPSMixedPrecision {
pub fn new(device: &Device) -> Self {
Self {
device: device.clone(),
loss_scaling: 65536.0, initial_loss_scale: 65536.0,
loss_scale_factor: 2.0,
scale_window: 2000,
min_loss_scale: 1.0,
max_loss_scale: 65536.0 * 65536.0, consecutive_unskipped: 0,
enabled: true,
found_inf: false,
scale_growth_tracker: 0,
}
}
pub fn set_enabled(&mut self, enabled: bool) {
self.enabled = enabled;
}
pub fn is_enabled(&self) -> bool {
self.enabled
}
pub fn get_loss_scale(&self) -> f32 {
if self.enabled {
self.loss_scaling
} else {
1.0
}
}
pub fn scale_loss(
&self,
command_buffer: &CommandBuffer,
loss: &MetalBuffer,
scaled_loss: &MetalBuffer,
) -> Result<()> {
if !self.enabled {
return self.copy_buffer(command_buffer, loss, scaled_loss);
}
self.scale_tensor(command_buffer, loss, self.loss_scaling, scaled_loss)
}
pub fn unscale_gradients(
&mut self,
command_buffer: &CommandBuffer,
gradients: &[MetalBuffer],
unscaled_gradients: &[MetalBuffer],
) -> Result<bool> {
if !self.enabled {
for (grad, unscaled) in gradients.iter().zip(unscaled_gradients.iter()) {
self.copy_buffer(command_buffer, grad, unscaled)?;
}
return Ok(true);
}
self.found_inf = false;
for gradient in gradients {
if self.has_inf_or_nan(command_buffer, gradient)? {
self.found_inf = true;
break;
}
}
if self.found_inf {
self.consecutive_unskipped = 0;
self.reduce_loss_scale();
return Ok(false);
}
let inv_scale = 1.0 / self.loss_scaling;
for (grad, unscaled) in gradients.iter().zip(unscaled_gradients.iter()) {
self.scale_tensor(command_buffer, grad, inv_scale, unscaled)?;
}
self.consecutive_unskipped += 1;
self.update_loss_scale();
Ok(true)
}
pub fn to_half_precision(
&self,
command_buffer: &CommandBuffer,
input: &MetalBuffer,
output: &MetalBuffer,
) -> Result<()> {
self.cast_tensor(command_buffer, input, output, MPSDataType::Float16)
}
pub fn to_full_precision(
&self,
command_buffer: &CommandBuffer,
input: &MetalBuffer,
output: &MetalBuffer,
) -> Result<()> {
self.cast_tensor(command_buffer, input, output, MPSDataType::Float32)
}
fn reduce_loss_scale(&mut self) {
self.loss_scaling = (self.loss_scaling / self.loss_scale_factor).max(self.min_loss_scale);
self.scale_growth_tracker = 0;
}
fn update_loss_scale(&mut self) {
if self.consecutive_unskipped >= self.scale_window {
self.loss_scaling =
(self.loss_scaling * self.loss_scale_factor).min(self.max_loss_scale);
self.consecutive_unskipped = 0;
self.scale_growth_tracker += 1;
}
}
fn scale_tensor(
&self,
command_buffer: &CommandBuffer,
input: &MetalBuffer,
_scale: f32,
output: &MetalBuffer,
) -> Result<()> {
self.copy_buffer(command_buffer, input, output)
}
fn copy_buffer(
&self,
_command_buffer: &CommandBuffer,
_src: &MetalBuffer,
_dst: &MetalBuffer,
) -> Result<()> {
Ok(())
}
fn cast_tensor(
&self,
_command_buffer: &CommandBuffer,
_input: &MetalBuffer,
_output: &MetalBuffer,
_target_type: MPSDataType,
) -> Result<()> {
Ok(())
}
fn has_inf_or_nan(
&self,
_command_buffer: &CommandBuffer,
_tensor: &MetalBuffer,
) -> Result<bool> {
Ok(false)
}
pub fn get_stats(&self) -> MixedPrecisionStats {
MixedPrecisionStats {
current_loss_scale: self.loss_scaling,
consecutive_unskipped: self.consecutive_unskipped,
scale_growth_tracker: self.scale_growth_tracker,
found_inf_last_step: self.found_inf,
enabled: self.enabled,
}
}
pub fn reset(&mut self) {
self.loss_scaling = self.initial_loss_scale;
self.consecutive_unskipped = 0;
self.scale_growth_tracker = 0;
self.found_inf = false;
}
}
/// Point-in-time snapshot of the loss-scaling state, produced by
/// `MPSMixedPrecision::get_stats`.
#[derive(Debug, Clone, PartialEq)]
pub struct MixedPrecisionStats {
    /// Raw loss scale. This is reported even when scaling is disabled.
    pub current_loss_scale: f32,
    /// Clean steps since the last growth or overflow.
    pub consecutive_unskipped: usize,
    /// Growth events since the last backoff or reset.
    pub scale_growth_tracker: usize,
    /// Whether the most recent unscale step found a non-finite gradient.
    pub found_inf_last_step: bool,
    /// Whether loss scaling was enabled when the snapshot was taken.
    pub enabled: bool,
}
/// Per-operation precision policy.
///
/// It decides which ops run in FP16 and converts their input buffers accordingly.
pub struct MPSAutocast {
    /// Master switch. When `false`, no op is routed to FP16.
    enabled: bool,
    /// Metal device this policy was created for.
    device: Device,
    /// Loss-scaling helper that also performs the precision casts.
    mixed_precision: MPSMixedPrecision,
    /// Op name → `true` if the op should run in FP16. Unlisted ops use FP32.
    fp16_ops: HashMap<String, bool>,
}
impl MPSAutocast {
    /// Creates an autocast policy for `device`.
    ///
    /// `enabled` gates `should_use_fp16`. The embedded `MPSMixedPrecision` is always
    /// created in its enabled state.
    pub fn new(device: &Device, enabled: bool) -> Self {
        // Ops marked `true` run in FP16. Ops marked `false` (softmax variants, losses,
        // normalisation layers) stay in FP32. Unlisted ops also default to FP32.
        let fp16_ops: HashMap<String, bool> = [
            ("conv2d", true),
            ("linear", true),
            ("matmul", true),
            ("bmm", true),
            ("addmm", true),
            ("softmax", false),
            ("log_softmax", false),
            ("cross_entropy", false),
            ("mse_loss", false),
            ("layer_norm", false),
            ("batch_norm", false),
        ]
        .iter()
        .map(|&(name, use_fp16)| (name.to_string(), use_fp16))
        .collect();

        Self {
            device: device.clone(),
            enabled,
            mixed_precision: MPSMixedPrecision::new(device),
            fp16_ops,
        }
    }

    /// Returns `true` if autocast is enabled and `op_name` is registered as an FP16 op.
    pub fn should_use_fp16(&self, op_name: &str) -> bool {
        if !self.enabled {
            return false;
        }
        self.fp16_ops.get(op_name).copied().unwrap_or(false)
    }

    /// Returns `inputs` converted to the precision `op_name` should run in.
    ///
    /// - For FP16 ops, every input is cast into a fresh F16 buffer.
    /// - For all other ops, F16 inputs are cast into fresh F32 buffers, and every other
    ///   input is returned as a clone.
    ///
    /// The output vector has one entry per input, in the same order.
    ///
    /// # Errors
    /// Propagates failures from device creation, buffer allocation, or the cast.
    pub fn autocast_inputs(
        &self,
        command_buffer: &CommandBuffer,
        op_name: &str,
        inputs: &[MetalBuffer],
    ) -> Result<Vec<MetalBuffer>> {
        let use_fp16 = self.should_use_fp16(op_name);
        let needs_conversion =
            move |input: &MetalBuffer| use_fp16 || input.dtype() == torsh_core::DType::F16;

        // Build the allocation device once, and only when some input actually needs a
        // new buffer, instead of once per converted input.
        // NOTE(review): this is a fresh `MetalDevice` rather than `self.device`; confirm
        // that is intended.
        let target_device = if inputs.iter().any(needs_conversion) {
            Some(crate::metal::device::MetalDevice::new()?)
        } else {
            None
        };
        let target_dtype = if use_fp16 {
            torsh_core::DType::F16
        } else {
            torsh_core::DType::F32
        };

        let mut converted_inputs = Vec::with_capacity(inputs.len());
        for input in inputs {
            if !needs_conversion(input) {
                converted_inputs.push(input.clone());
                continue;
            }
            let device = target_device
                .as_ref()
                .expect("target device is created whenever any input needs conversion");
            let converted = MetalBuffer::zeros(input.shape(), &target_dtype, device)?;
            if use_fp16 {
                self.mixed_precision
                    .to_half_precision(command_buffer, input, &converted)?;
            } else {
                self.mixed_precision
                    .to_full_precision(command_buffer, input, &converted)?;
            }
            converted_inputs.push(converted);
        }
        Ok(converted_inputs)
    }

    /// Mutable access to the embedded loss-scaling helper.
    pub fn mixed_precision(&mut self) -> &mut MPSMixedPrecision {
        &mut self.mixed_precision
    }
}
/// Gradient scaler in the style of an AMP `GradScaler`.
///
/// It is a thin wrapper over `MPSMixedPrecision`'s dynamic loss scaling.
pub struct MPSGradScaler {
    /// Steps counted toward `update_frequency`. It wraps back to zero at the limit.
    update_counter: usize,
    /// Wrap-around limit for `update_counter`.
    update_frequency: usize,
    /// Underlying loss-scaling state.
    mixed_precision: MPSMixedPrecision,
}
impl MPSGradScaler {
    /// Creates a scaler starting at `initial_scale`.
    ///
    /// `growth_factor` is both the multiplier applied when the scale grows and the
    /// divisor applied on overflow backoff.
    pub fn new(device: &Device, initial_scale: f32, growth_factor: f32) -> Self {
        let mut mixed_precision = MPSMixedPrecision::new(device);
        mixed_precision.loss_scaling = initial_scale;
        // Also record the starting scale so `MPSMixedPrecision::reset` restores the
        // caller's value rather than the 65536 default set by `MPSMixedPrecision::new`.
        mixed_precision.initial_loss_scale = initial_scale;
        mixed_precision.loss_scale_factor = growth_factor;
        Self {
            mixed_precision,
            update_frequency: 2000,
            update_counter: 0,
        }
    }

    /// Writes the scaled loss into `scaled_loss`.
    ///
    /// See `MPSMixedPrecision::scale_loss`.
    pub fn scale(
        &self,
        command_buffer: &CommandBuffer,
        loss: &MetalBuffer,
        scaled_loss: &MetalBuffer,
    ) -> Result<()> {
        self.mixed_precision
            .scale_loss(command_buffer, loss, scaled_loss)
    }

    /// Unscales `gradients` into `unscaled_gradients` and advances the loss-scale schedule.
    ///
    /// Returns `Ok(true)` if the optimizer step should be applied, and `Ok(false)` if a
    /// non-finite gradient was found and the step should be skipped.
    ///
    /// # Errors
    /// Propagates errors from `MPSMixedPrecision::unscale_gradients`.
    pub fn step(
        &mut self,
        command_buffer: &CommandBuffer,
        gradients: &[MetalBuffer],
        unscaled_gradients: &[MetalBuffer],
    ) -> Result<bool> {
        let should_update = self.mixed_precision.unscale_gradients(
            command_buffer,
            gradients,
            unscaled_gradients,
        )?;
        // NOTE(review): nothing outside this method reads `update_counter`, so this
        // wrap-around bookkeeping currently has no observable effect.
        self.update_counter += 1;
        if self.update_counter >= self.update_frequency {
            self.update_counter = 0;
        }
        Ok(should_update)
    }

    /// Effective loss scale. This is `1.0` if the underlying helper is disabled.
    pub fn get_scale(&self) -> f32 {
        self.mixed_precision.get_loss_scale()
    }

    /// Whether the most recent `step` detected a non-finite gradient.
    pub fn found_inf(&self) -> bool {
        self.mixed_precision.found_inf
    }
}
pub mod utils {
    //! Helpers for configuring and estimating the benefit of automatic mixed precision.
    use super::*;

    /// Fraction of memory assumed to be saved by mixed precision (fixed estimate).
    const ESTIMATED_MEMORY_SAVINGS: f32 = 0.35;
    /// Assumed speedup on devices with efficient FP16 support.
    const SPEEDUP_WITH_EFFICIENT_FP16: f32 = 1.7;
    /// Assumed speedup on devices without efficient FP16 support.
    const SPEEDUP_WITHOUT_EFFICIENT_FP16: f32 = 1.1;

    /// Returns the default AMP configuration.
    ///
    /// The defaults are: enabled, level `O1`, initial loss scale 128 clamped to
    /// `[1, 65536]`, growth ×2, backoff ×0.5, and a growth interval of 2000 steps.
    pub fn create_amp_config() -> AMPConfig {
        AMPConfig {
            enabled: true,
            opt_level: OptLevel::O1,
            loss_scale: Some(128.0),
            min_loss_scale: 1.0,
            max_loss_scale: 65536.0,
            growth_factor: 2.0,
            backoff_factor: 0.5,
            growth_interval: 2000,
        }
    }

    /// Reports whether `_device` handles FP16 efficiently.
    ///
    /// This currently returns `true` unconditionally; the device is not inspected.
    pub fn supports_efficient_fp16(_device: &Device) -> bool {
        true
    }

    /// Estimated fraction of memory saved by mixed precision.
    ///
    /// This is a fixed estimate that ignores `_model_params`.
    pub fn estimate_memory_savings(_model_params: usize) -> f32 {
        ESTIMATED_MEMORY_SAVINGS
    }

    /// Estimated speedup factor from mixed precision on `device`.
    pub fn estimate_performance_improvement(device: &Device) -> f32 {
        let efficient = supports_efficient_fp16(device);
        if efficient {
            SPEEDUP_WITH_EFFICIENT_FP16
        } else {
            SPEEDUP_WITHOUT_EFFICIENT_FP16
        }
    }
}
/// Configuration for automatic mixed precision (AMP).
///
/// The default value comes from `utils::create_amp_config`.
/// NOTE(review): nothing in this file feeds these fields into `MPSMixedPrecision` or
/// `MPSGradScaler`; confirm where (if anywhere) they are consumed.
#[derive(Debug, Clone)]
pub struct AMPConfig {
/// Whether AMP is turned on.
pub enabled: bool,
/// Optimization level (`O1` in the default config).
pub opt_level: OptLevel,
/// Initial loss scale (128 in the default config). Presumably `None` means
/// "use the backend default" — TODO confirm.
pub loss_scale: Option<f32>,
/// Upper clamp for the loss scale.
pub max_loss_scale: f32,
/// Lower clamp for the loss scale.
pub min_loss_scale: f32,
/// Presumably the multiplier applied when the loss scale grows — TODO confirm.
pub growth_factor: f32,
/// Presumably the multiplier applied to the loss scale after an overflow — TODO confirm.
pub backoff_factor: f32,
/// Presumably the number of clean steps between growth events — TODO confirm.
pub growth_interval: usize,
}
/// AMP optimization level.
///
/// The O0–O3 naming presumably mirrors the NVIDIA Apex AMP levels — TODO confirm.
/// Within this file, only `O1` is used, as the default in `utils::create_amp_config`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OptLevel {
    O0,
    O1,
    O2,
    O3,
}
impl Default for AMPConfig {
fn default() -> Self {
utils::create_amp_config()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh helper is enabled with the 2^16 default scale; disabling it makes the
    /// effective scale 1.
    #[test]
    fn test_mixed_precision_creation() {
        let Some(device) = metal::Device::system_default() else {
            return;
        };
        let mut mp = MPSMixedPrecision::new(&device);
        assert!(mp.is_enabled());
        assert_eq!(mp.get_loss_scale(), 65536.0);
        mp.set_enabled(false);
        assert!(!mp.is_enabled());
        assert_eq!(mp.get_loss_scale(), 1.0);
    }

    #[test]
    fn test_loss_scale_update() {
        let Some(device) = metal::Device::system_default() else {
            return;
        };
        let mut mp = MPSMixedPrecision::new(&device);
        let initial_scale = mp.get_loss_scale();
        mp.consecutive_unskipped = 2000;
        mp.update_loss_scale();
        assert!(mp.get_loss_scale() > initial_scale);
    }

    /// Backoff never drops the scale below `min_loss_scale`.
    #[test]
    fn test_loss_scale_reduce_clamps_to_min() {
        let Some(device) = metal::Device::system_default() else {
            return;
        };
        let mut mp = MPSMixedPrecision::new(&device);
        mp.loss_scaling = 1.5;
        mp.reduce_loss_scale();
        assert_eq!(mp.get_loss_scale(), 1.0);
    }

    /// `reset` undoes growth and clears the counters.
    #[test]
    fn test_reset_restores_initial_scale() {
        let Some(device) = metal::Device::system_default() else {
            return;
        };
        let mut mp = MPSMixedPrecision::new(&device);
        mp.consecutive_unskipped = 2000;
        mp.update_loss_scale();
        assert_eq!(mp.get_stats().scale_growth_tracker, 1);
        mp.reset();
        let stats = mp.get_stats();
        assert_eq!(stats.current_loss_scale, 65536.0);
        assert_eq!(stats.scale_growth_tracker, 0);
        assert_eq!(stats.consecutive_unskipped, 0);
        assert!(!stats.found_inf_last_step);
    }

    #[test]
    fn test_autocast_op_detection() {
        let Some(device) = metal::Device::system_default() else {
            return;
        };
        let autocast = MPSAutocast::new(&device, true);
        assert!(autocast.should_use_fp16("conv2d"));
        assert!(autocast.should_use_fp16("linear"));
        assert!(!autocast.should_use_fp16("softmax"));
        assert!(!autocast.should_use_fp16("layer_norm"));
    }

    #[test]
    fn test_amp_config() {
        let config = AMPConfig::default();
        assert!(config.enabled);
        assert!(matches!(config.opt_level, OptLevel::O1));
        assert!(config.growth_factor > 1.0);
        assert!(config.min_loss_scale <= config.max_loss_scale);
    }
}