use candle_core::{DType, Tensor};
use crate::error::{Result, UnslothError};
use crate::memory::CheckpointConfig;
/// Floating-point precision used for computation in mixed-precision training.
///
/// Maps one-to-one onto the candle dtypes `F32`, `F16`, and `BF16`
/// (see `to_dtype` / `from_dtype`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PrecisionMode {
/// 32-bit IEEE float (`DType::F32`).
Full,
/// 16-bit IEEE float (`DType::F16`).
Half,
/// 16-bit brain floating point (`DType::BF16`).
BFloat16,
}
impl PrecisionMode {
    /// Returns the candle [`DType`] corresponding to this precision mode.
    // Takes `self` by value: the enum is `Copy`, so a `&self` receiver is
    // pure noise (clippy `trivially_copy_pass_by_ref`). Method-call syntax
    // for existing callers is unaffected.
    #[must_use]
    pub fn to_dtype(self) -> DType {
        match self {
            Self::Full => DType::F32,
            Self::Half => DType::F16,
            Self::BFloat16 => DType::BF16,
        }
    }

    /// Maps a candle [`DType`] to the precision mode it represents.
    ///
    /// # Errors
    ///
    /// Returns `UnslothError::InvalidConfig` when `dtype` is not one of
    /// `F32`, `F16`, or `BF16` (integer and quantized dtypes are not
    /// valid compute precisions).
    pub fn from_dtype(dtype: DType) -> Result<Self> {
        match dtype {
            DType::F32 => Ok(Self::Full),
            DType::F16 => Ok(Self::Half),
            DType::BF16 => Ok(Self::BFloat16),
            _ => Err(UnslothError::InvalidConfig(format!(
                "Unsupported dtype for mixed precision: {dtype:?}"
            ))),
        }
    }
}
/// Configuration for mixed-precision training with (optionally dynamic)
/// loss scaling; consumed by `scale_loss`, `unscale_gradients`, and
/// `update_loss_scale`.
#[derive(Debug, Clone)]
pub struct MixedPrecisionConfig {
/// Precision used for forward/backward computation.
pub compute_precision: PrecisionMode,
/// Precision for the master parameters (not consumed elsewhere in this
/// file; presumably the optimizer's master-weight dtype — confirm at call sites).
pub master_precision: PrecisionMode,
/// Current multiplier applied to the loss before backprop (`scale_loss`).
pub loss_scale: f32,
/// When true, `update_loss_scale` adjusts `loss_scale` automatically.
pub dynamic_loss_scale: bool,
/// Lower clamp on `loss_scale` when backing off after an overflow.
pub min_loss_scale: f32,
/// Upper clamp on `loss_scale` during growth.
pub max_loss_scale: f32,
/// Multiplier applied to grow `loss_scale` after an overflow-free interval.
pub scale_growth_factor: f32,
/// Multiplier applied to shrink `loss_scale` on overflow.
pub scale_backoff_factor: f32,
/// Number of consecutive overflow-free steps required before growth.
pub scale_growth_interval: usize,
}
impl Default for MixedPrecisionConfig {
    /// Conservative FP16 defaults: half-precision compute, full-precision
    /// master values, and dynamic loss scaling starting at 2^16.
    fn default() -> Self {
        Self {
            compute_precision: PrecisionMode::Half,
            master_precision: PrecisionMode::Full,
            loss_scale: 65536.0,
            dynamic_loss_scale: true,
            min_loss_scale: 1.0,
            max_loss_scale: 2_147_483_648.0,
            scale_growth_factor: 2.0,
            scale_backoff_factor: 0.5,
            scale_growth_interval: 2000,
        }
    }
}
impl MixedPrecisionConfig {
    /// Builds a config with the given compute precision; every other
    /// field keeps its default value.
    #[must_use]
    pub fn new(compute_precision: PrecisionMode) -> Self {
        Self {
            compute_precision,
            ..Self::default()
        }
    }

    /// Shorthand for an FP16-compute configuration.
    #[must_use]
    pub fn fp16() -> Self {
        Self::new(PrecisionMode::Half)
    }

    /// Shorthand for a BF16-compute configuration.
    #[must_use]
    pub fn bf16() -> Self {
        Self::new(PrecisionMode::BFloat16)
    }

    /// Full-precision configuration: loss scaling is effectively disabled
    /// (static scale of 1.0), since FP32 does not overflow in practice.
    #[must_use]
    pub fn fp32() -> Self {
        let mut config = Self::new(PrecisionMode::Full);
        config.loss_scale = 1.0;
        config.dynamic_loss_scale = false;
        config
    }
}
/// Top-level training hyperparameters.
#[derive(Debug, Clone)]
pub struct TrainingConfig {
/// Per-step batch size.
pub batch_size: usize,
/// Maximum sequence length.
pub max_seq_len: usize,
/// Number of micro-batches accumulated before an optimizer step.
pub gradient_accumulation_steps: usize,
/// Mixed-precision settings; `None` disables mixed precision entirely.
pub mixed_precision: Option<MixedPrecisionConfig>,
/// Gradient-checkpointing configuration (from `crate::memory`).
pub checkpoint_config: CheckpointConfig,
}
impl Default for TrainingConfig {
    /// Defaults: batch size 4, sequences up to 2048, 4 accumulation steps,
    /// mixed precision enabled with `MixedPrecisionConfig::default()`.
    fn default() -> Self {
        Self {
            mixed_precision: Some(MixedPrecisionConfig::default()),
            checkpoint_config: CheckpointConfig::default(),
            gradient_accumulation_steps: 4,
            max_seq_len: 2048,
            batch_size: 4,
        }
    }
}
/// Casts `tensor` to the dtype implied by `precision`.
///
/// Returns a cheap clone when the tensor is already in the target dtype.
///
/// # Errors
///
/// Propagates any dtype-conversion error from candle.
pub fn convert_precision(tensor: &Tensor, precision: PrecisionMode) -> Result<Tensor> {
    let target = precision.to_dtype();
    if tensor.dtype() == target {
        return Ok(tensor.clone());
    }
    Ok(tensor.to_dtype(target)?)
}
/// Multiplies the loss by `config.loss_scale` before backpropagation.
///
/// A scale of exactly 1.0 (within `f32::EPSILON`) is a no-op clone.
///
/// # Errors
///
/// Propagates any tensor-arithmetic error from candle.
pub fn scale_loss(loss: &Tensor, config: &MixedPrecisionConfig) -> Result<Tensor> {
    let scale = config.loss_scale;
    if (scale - 1.0).abs() < f32::EPSILON {
        return Ok(loss.clone());
    }
    Ok((loss * f64::from(scale))?)
}
/// Divides every gradient by `config.loss_scale`, undoing `scale_loss`.
///
/// A scale of exactly 1.0 (within `f32::EPSILON`) simply clones the slice.
///
/// # Errors
///
/// Propagates any tensor-arithmetic error from candle.
pub fn unscale_gradients(
    gradients: &[Tensor],
    config: &MixedPrecisionConfig,
) -> Result<Vec<Tensor>> {
    if (config.loss_scale - 1.0).abs() < f32::EPSILON {
        return Ok(gradients.to_vec());
    }
    let inv_scale = 1.0 / f64::from(config.loss_scale);
    let mut unscaled = Vec::with_capacity(gradients.len());
    for grad in gradients {
        unscaled.push((grad * inv_scale)?);
    }
    Ok(unscaled)
}
/// Returns `true` if any gradient tensor contains a non-finite value
/// (NaN or ±infinity), signalling a loss-scale overflow.
///
/// Each gradient is converted to `F32` and copied to the host for
/// inspection; the scan short-circuits on the first non-finite value.
///
/// # Errors
///
/// Propagates any dtype-conversion / reshape / host-copy error from candle.
pub fn has_inf_or_nan(gradients: &[Tensor]) -> Result<bool> {
    for grad in gradients {
        let values: Vec<f32> = grad.to_dtype(DType::F32)?.flatten_all()?.to_vec1()?;
        // `!is_finite()` covers both NaN and ±inf in one check.
        if values.iter().any(|v| !v.is_finite()) {
            return Ok(true);
        }
    }
    Ok(false)
}
/// Updates the dynamic loss scale after an optimizer step and returns the
/// (possibly updated) scale.
///
/// - If dynamic scaling is disabled, the current scale is returned unchanged.
/// - On overflow, the scale is multiplied by `scale_backoff_factor` and
///   clamped to `min_loss_scale` from below.
/// - Otherwise, once `steps_since_overflow` reaches `scale_growth_interval`,
///   the scale is multiplied by `scale_growth_factor` and clamped to
///   `max_loss_scale` from above.
// NOTE: the previous `#[allow(clippy::cast_possible_truncation)]` and
// `#[allow(clippy::cast_sign_loss)]` attributes were stale — this function
// performs no numeric casts — so they have been removed.
pub fn update_loss_scale(
    config: &mut MixedPrecisionConfig,
    has_overflow: bool,
    steps_since_overflow: usize,
) -> f32 {
    if !config.dynamic_loss_scale {
        return config.loss_scale;
    }
    if has_overflow {
        config.loss_scale =
            (config.loss_scale * config.scale_backoff_factor).max(config.min_loss_scale);
    } else if steps_since_overflow >= config.scale_growth_interval {
        config.loss_scale =
            (config.loss_scale * config.scale_growth_factor).min(config.max_loss_scale);
    }
    config.loss_scale
}
/// Stub for gradient checkpointing: recompute activations during the
/// backward pass instead of storing them. Not yet implemented.
///
/// # Errors
///
/// Always returns `UnslothError::InvalidConfig` until the feature lands.
pub fn compute_gradient_checkpointed<F>(
    _input: &Tensor,
    _forward_fn: F,
    _config: &CheckpointConfig,
) -> Result<Tensor>
where
    F: Fn(&Tensor) -> Result<Tensor>,
{
    let message =
        "Gradient checkpointing is not yet implemented. This feature is planned for a future release.";
    Err(UnslothError::InvalidConfig(message.to_string()))
}
/// Multiplies every gradient tensor by `scale`, returning new tensors.
///
/// # Errors
///
/// Propagates any tensor-arithmetic error from candle.
pub fn scale_gradients(gradients: &[Tensor], scale: f32) -> Result<Vec<Tensor>> {
    let factor = f64::from(scale);
    let mut scaled = Vec::with_capacity(gradients.len());
    for grad in gradients {
        scaled.push((grad * factor)?);
    }
    Ok(scaled)
}
#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::Device;

    #[test]
    fn test_training_config_default() {
        let config = TrainingConfig::default();
        assert_eq!(config.batch_size, 4);
        assert!(config.mixed_precision.is_some());
    }

    #[test]
    fn test_precision_mode_to_dtype() {
        assert_eq!(PrecisionMode::Full.to_dtype(), DType::F32);
        assert_eq!(PrecisionMode::Half.to_dtype(), DType::F16);
        assert_eq!(PrecisionMode::BFloat16.to_dtype(), DType::BF16);
    }

    #[test]
    fn test_precision_mode_from_dtype() {
        assert_eq!(
            PrecisionMode::from_dtype(DType::F32).unwrap(),
            PrecisionMode::Full
        );
        assert_eq!(
            PrecisionMode::from_dtype(DType::F16).unwrap(),
            PrecisionMode::Half
        );
        assert_eq!(
            PrecisionMode::from_dtype(DType::BF16).unwrap(),
            PrecisionMode::BFloat16
        );
        assert!(PrecisionMode::from_dtype(DType::U8).is_err());
    }

    #[test]
    fn test_mixed_precision_config_defaults() {
        let config = MixedPrecisionConfig::default();
        assert_eq!(config.compute_precision, PrecisionMode::Half);
        assert_eq!(config.master_precision, PrecisionMode::Full);
        assert_eq!(config.loss_scale, 65536.0);
        assert!(config.dynamic_loss_scale);
    }

    #[test]
    fn test_mixed_precision_config_fp16() {
        let config = MixedPrecisionConfig::fp16();
        assert_eq!(config.compute_precision, PrecisionMode::Half);
        assert_eq!(config.master_precision, PrecisionMode::Full);
    }

    #[test]
    fn test_mixed_precision_config_bf16() {
        let config = MixedPrecisionConfig::bf16();
        assert_eq!(config.compute_precision, PrecisionMode::BFloat16);
    }

    #[test]
    fn test_mixed_precision_config_fp32() {
        let config = MixedPrecisionConfig::fp32();
        assert_eq!(config.compute_precision, PrecisionMode::Full);
        assert_eq!(config.master_precision, PrecisionMode::Full);
        assert!(!config.dynamic_loss_scale);
        assert_eq!(config.loss_scale, 1.0);
    }

    #[test]
    fn test_convert_precision() {
        let device = Device::Cpu;
        let tensor = Tensor::ones((2, 3), DType::F32, &device).unwrap();
        let fp16 = convert_precision(&tensor, PrecisionMode::Half).unwrap();
        assert_eq!(fp16.dtype(), DType::F16);
        let bf16 = convert_precision(&tensor, PrecisionMode::BFloat16).unwrap();
        assert_eq!(bf16.dtype(), DType::BF16);
        let same = convert_precision(&tensor, PrecisionMode::Full).unwrap();
        assert_eq!(same.dtype(), DType::F32);
    }

    #[test]
    fn test_scale_loss() {
        let device = Device::Cpu;
        let loss = Tensor::full(2.0f32, (), &device).unwrap();
        // Struct-update syntax instead of mutate-after-default
        // (clippy `field_reassign_with_default`).
        let config = MixedPrecisionConfig {
            loss_scale: 4.0,
            ..Default::default()
        };
        let scaled = scale_loss(&loss, &config).unwrap();
        let value: f32 = scaled.to_scalar().unwrap();
        assert!((value - 8.0).abs() < 1e-5);
    }

    #[test]
    fn test_unscale_gradients() {
        let device = Device::Cpu;
        let grad1 = Tensor::full(8.0f32, (2, 2), &device).unwrap();
        let grad2 = Tensor::full(16.0f32, (2, 2), &device).unwrap();
        let gradients = vec![grad1, grad2];
        let config = MixedPrecisionConfig {
            loss_scale: 4.0,
            ..Default::default()
        };
        let unscaled = unscale_gradients(&gradients, &config).unwrap();
        let vals1: Vec<f32> = unscaled[0].flatten_all().unwrap().to_vec1().unwrap();
        for val in vals1 {
            assert!((val - 2.0).abs() < 1e-5);
        }
        let vals2: Vec<f32> = unscaled[1].flatten_all().unwrap().to_vec1().unwrap();
        for val in vals2 {
            assert!((val - 4.0).abs() < 1e-5);
        }
    }

    #[test]
    fn test_has_inf_or_nan() {
        let device = Device::Cpu;
        let grad1 = Tensor::ones((2, 2), DType::F32, &device).unwrap();
        let grad2 = Tensor::full(2.0f32, (2, 2), &device).unwrap();
        assert!(!has_inf_or_nan(&[grad1, grad2]).unwrap());
        let nan_grad = Tensor::full(f32::NAN, (2, 2), &device).unwrap();
        assert!(has_inf_or_nan(&[nan_grad]).unwrap());
        let inf_grad = Tensor::full(f32::INFINITY, (2, 2), &device).unwrap();
        assert!(has_inf_or_nan(&[inf_grad]).unwrap());
    }

    #[test]
    fn test_update_loss_scale_on_overflow() {
        let mut config = MixedPrecisionConfig {
            loss_scale: 1000.0,
            scale_backoff_factor: 0.5,
            ..Default::default()
        };
        let new_scale = update_loss_scale(&mut config, true, 0);
        assert_eq!(new_scale, 500.0);
        assert_eq!(config.loss_scale, 500.0);
    }

    #[test]
    fn test_update_loss_scale_growth() {
        let mut config = MixedPrecisionConfig {
            loss_scale: 100.0,
            scale_growth_factor: 2.0,
            scale_growth_interval: 100,
            ..Default::default()
        };
        let new_scale = update_loss_scale(&mut config, false, 100);
        assert_eq!(new_scale, 200.0);
        assert_eq!(config.loss_scale, 200.0);
    }

    #[test]
    fn test_update_loss_scale_no_change() {
        let mut config = MixedPrecisionConfig {
            loss_scale: 100.0,
            ..Default::default()
        };
        let new_scale = update_loss_scale(&mut config, false, 10);
        assert_eq!(new_scale, 100.0);
    }

    #[test]
    fn test_update_loss_scale_bounds() {
        let mut config = MixedPrecisionConfig {
            min_loss_scale: 1.0,
            max_loss_scale: 1000.0,
            loss_scale: 2.0,
            scale_backoff_factor: 0.5,
            ..Default::default()
        };
        update_loss_scale(&mut config, true, 0);
        assert!((config.loss_scale - 1.0).abs() < f32::EPSILON);
        config.loss_scale = 600.0;
        config.scale_growth_factor = 2.0;
        config.scale_growth_interval = 10;
        update_loss_scale(&mut config, false, 10);
        assert!((config.loss_scale - 1000.0).abs() < f32::EPSILON);
    }

    #[test]
    fn test_scale_gradients() {
        let device = Device::Cpu;
        let grad1 = Tensor::ones((2, 3), DType::F32, &device).unwrap();
        let grad2 = Tensor::full(2.0f32, (2, 3), &device).unwrap();
        let gradients = vec![grad1, grad2];
        let scale = 0.5;
        let scaled = scale_gradients(&gradients, scale).unwrap();
        let vals1: Vec<f32> = scaled[0].flatten_all().unwrap().to_vec1().unwrap();
        for val in vals1 {
            assert!((val - 0.5).abs() < 1e-5);
        }
        let vals2: Vec<f32> = scaled[1].flatten_all().unwrap().to_vec1().unwrap();
        for val in vals2 {
            assert!((val - 1.0).abs() < 1e-5);
        }
    }
}