//! Training configuration: loop hyperparameters, per-step metrics, and
//! mixed-precision settings with their loss-scaling strategies.

use crate::error::Result;
use crate::optimizer::grad_scaler::GradScaler;
use numr::dtype::DType;
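
/// Hyperparameters for the training loop.
///
/// # Examples
///
/// A minimal sketch of the builder-style setters defined below (the crate
/// path is not assumed, so the snippet is not compiled as a doc-test):
///
/// ```ignore
/// let config = TrainingConfig::default()
///     .with_lr(3e-4)
///     .with_grad_accum_steps(8);
/// assert_eq!(config.grad_accum_steps, 8);
/// ```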
#[derive(Debug, Clone)]
pub struct TrainingConfig {
    /// Step size passed to the optimizer.
    pub learning_rate: f64,
    /// L2 weight-decay coefficient.
    pub weight_decay: f64,
    /// Gradient-norm clipping threshold; `None` disables clipping.
    pub max_grad_norm: Option<f64>,
    /// Number of micro-batches to accumulate before each optimizer step.
    pub grad_accum_steps: usize,
}

impl Default for TrainingConfig {
    fn default() -> Self {
        Self {
            learning_rate: 1e-3,
            weight_decay: 0.01,
            max_grad_norm: Some(1.0),
            grad_accum_steps: 1,
        }
    }
}

impl TrainingConfig {
    /// Sets the learning rate.
    pub fn with_lr(mut self, lr: f64) -> Self {
        self.learning_rate = lr;
        self
    }

    /// Sets the weight-decay coefficient.
    pub fn with_weight_decay(mut self, wd: f64) -> Self {
        self.weight_decay = wd;
        self
    }

    /// Sets the gradient-clipping threshold (`None` disables clipping).
    pub fn with_max_grad_norm(mut self, norm: Option<f64>) -> Self {
        self.max_grad_norm = norm;
        self
    }

    /// Sets the number of gradient-accumulation steps.
    pub fn with_grad_accum_steps(mut self, steps: usize) -> Self {
        self.grad_accum_steps = steps;
        self
    }
}

/// Metrics reported after each optimizer step.
#[derive(Debug, Clone)]
pub struct TrainingMetrics {
    /// Global optimizer step count.
    pub step: u64,
    /// Training loss at this step.
    pub loss: f64,
    /// Global gradient norm before clipping, if it was computed.
    pub grad_norm: Option<f64>,
    /// Learning rate used for this step.
    pub lr: f64,
}

/// Dtype and loss-scaling settings for mixed-precision training.
#[derive(Debug, Clone)]
pub struct MixedPrecisionConfig {
    /// Dtype used for the forward/backward computation.
    pub compute_dtype: DType,
    /// Dtype of the master copy of the weights.
    pub master_dtype: DType,
    /// Loss-scaling strategy applied during the backward pass.
    pub loss_scale: LossScaleStrategy,
}

impl MixedPrecisionConfig {
    /// bf16 compute with f32 master weights. bf16 keeps f32's exponent
    /// range, so no loss scaling is needed.
    pub fn bf16() -> Self {
        Self {
            compute_dtype: DType::BF16,
            master_dtype: DType::F32,
            loss_scale: LossScaleStrategy::None,
        }
    }

    /// f16 compute with f32 master weights. f16's narrow exponent range
    /// lets small gradients underflow to zero, so dynamic loss scaling is
    /// enabled by default.
    pub fn fp16() -> Self {
        Self {
            compute_dtype: DType::F16,
            master_dtype: DType::F32,
            loss_scale: LossScaleStrategy::Dynamic {
                initial_scale: 65536.0,
                growth_factor: 2.0,
                backoff_factor: 0.5,
                growth_interval: 2000,
            },
        }
    }
}

/// How the loss is scaled before the backward pass.
#[derive(Debug, Clone)]
pub enum LossScaleStrategy {
    /// No loss scaling.
    None,
    /// A constant scale factor.
    Fixed(f64),
    /// A scale that grows after `growth_interval` consecutive finite steps
    /// and backs off when non-finite gradients are encountered.
    Dynamic {
        initial_scale: f64,
        growth_factor: f64,
        backoff_factor: f64,
        growth_interval: u64,
    },
}

impl LossScaleStrategy {
    pub(crate) fn to_grad_scaler(&self) -> Result<Option<GradScaler>> {
        match self {
            LossScaleStrategy::None => Ok(None),
            // A fixed scale is modeled as a dynamic scaler whose growth
            // interval of `u64::MAX` is never reached, so the scale stays
            // constant.
            LossScaleStrategy::Fixed(scale) => {
                Ok(Some(GradScaler::new(*scale, 2.0, 0.5, u64::MAX)?))
            }
            LossScaleStrategy::Dynamic {
                initial_scale,
                growth_factor,
                backoff_factor,
                growth_interval,
            } => Ok(Some(GradScaler::new(
                *initial_scale,
                *growth_factor,
                *backoff_factor,
                *growth_interval,
            )?)),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_config() {
        let config = TrainingConfig::default();
        assert_eq!(config.learning_rate, 1e-3);
        assert_eq!(config.weight_decay, 0.01);
        assert_eq!(config.max_grad_norm, Some(1.0));
        assert_eq!(config.grad_accum_steps, 1);
    }
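
    // A sketch of the two mixed-precision presets defined above: bf16 needs
    // no loss scaling, while fp16 defaults to dynamic scaling. Assumes
    // `DType` implements `PartialEq` and `Debug` for the `assert_eq!` calls.
    #[test]
    fn test_mixed_precision_presets() {
        let bf16 = MixedPrecisionConfig::bf16();
        assert_eq!(bf16.compute_dtype, DType::BF16);
        assert_eq!(bf16.master_dtype, DType::F32);
        assert!(matches!(bf16.loss_scale, LossScaleStrategy::None));

        let fp16 = MixedPrecisionConfig::fp16();
        assert_eq!(fp16.compute_dtype, DType::F16);
        assert_eq!(fp16.master_dtype, DType::F32);
        assert!(matches!(fp16.loss_scale, LossScaleStrategy::Dynamic { .. }));
    }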

    #[test]
    fn test_builder() {
        let config = TrainingConfig::default()
            .with_lr(0.01)
            .with_weight_decay(0.1)
            .with_max_grad_norm(None)
            .with_grad_accum_steps(4);
        assert_eq!(config.learning_rate, 0.01);
        assert_eq!(config.weight_decay, 0.1);
        assert_eq!(config.max_grad_norm, None);
        assert_eq!(config.grad_accum_steps, 4);
    }
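
    // A sketch of the strategy-to-scaler mapping: `None` yields no scaler,
    // the other variants yield `Some(GradScaler)`. Uses only `to_grad_scaler`
    // and the same argument values as `fp16()` above.
    #[test]
    fn test_loss_scale_strategy_to_grad_scaler() {
        assert!(LossScaleStrategy::None.to_grad_scaler().unwrap().is_none());
        assert!(LossScaleStrategy::Fixed(1024.0)
            .to_grad_scaler()
            .unwrap()
            .is_some());
        let dynamic = LossScaleStrategy::Dynamic {
            initial_scale: 65536.0,
            growth_factor: 2.0,
            backoff_factor: 0.5,
            growth_interval: 2000,
        };
        assert!(dynamic.to_grad_scaler().unwrap().is_some());
    }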
}