//! burn_dragon_train 0.5.0
//!
//! Training utilities for `burn_dragon`.
use anyhow::{Result, anyhow};
use serde::{Deserialize, Serialize};

/// Selects which optimizer implementation drives training.
///
/// Serialized in snake_case (e.g. `"adamw"`, `"muon_hybrid_exp"`).
#[derive(Debug, Clone, Copy, Deserialize, Serialize, PartialEq, Eq, Default)]
#[serde(rename_all = "snake_case")]
pub enum OptimizerKind {
    /// AdamW; the default when the field is omitted.
    #[default]
    Adamw,
    /// NOTE(review): presumably the setup from the published BitNet training
    /// recipe — confirm against the optimizer implementation.
    BitnetPublished,
    /// Experimental Muon/AdamW hybrid; requires `OptimizerConfig::muon` to be
    /// set and enabled (enforced by `OptimizerConfig::validate`).
    MuonHybridExp,
}

/// Selects which reference schedule behavior the optimizer follows.
///
/// Serialized in snake_case. The concrete semantics of each mode live in the
/// training loop, not in this config module.
#[derive(Debug, Clone, Copy, Deserialize, Serialize, PartialEq, Eq, Default)]
#[serde(rename_all = "snake_case")]
pub enum OptimizerScheduleMode {
    /// Default mode. NOTE(review): presumably the BDH reference schedule —
    /// confirm against the training loop.
    #[default]
    BdhReference,
    /// NOTE(review): presumably the BitNet b1.58 reference schedule — confirm.
    BitnetB158Reference,
    /// NOTE(review): presumably a blend of the two reference modes — confirm.
    Hybrid,
}

/// Strategy for adjusting Muon's learning rate relative to the base rate.
///
/// Currently a single-variant enum kept for forward compatibility with
/// additional strategies.
#[derive(Debug, Clone, Copy, Deserialize, Serialize, PartialEq, Eq, Default)]
#[serde(rename_all = "snake_case")]
pub enum MuonAdjustLrFn {
    /// NOTE(review): name suggests matching the RMS of AdamW updates —
    /// confirm against the optimizer implementation.
    #[default]
    MatchRmsAdamw,
}

/// Settings for the experimental Muon hybrid optimizer
/// (`OptimizerKind::MuonHybridExp`).
///
/// Every omitted field falls back to `MuonHybridConfig::default()`
/// (`#[serde(default)]` on the container).
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
#[serde(default)]
pub struct MuonHybridConfig {
    /// Must be `true` when the Muon hybrid optimizer is selected
    /// (enforced by `OptimizerConfig::validate`).
    pub enabled: bool,
    /// Momentum coefficient; `validate` requires it to lie in (0, 1).
    pub momentum: f32,
    /// NOTE(review): presumably toggles Nesterov-style momentum — confirm
    /// against the optimizer implementation.
    pub nesterov: bool,
    /// Iteration count; `validate` requires > 0. Presumably the number of
    /// Newton–Schulz steps used by Muon — TODO confirm.
    pub ns_steps: usize,
    /// How the Muon learning rate is adjusted; see `MuonAdjustLrFn`.
    pub adjust_lr_fn: MuonAdjustLrFn,
    /// NOTE(review): presumably controls whether decoder head parameters are
    /// handled separately — confirm against the training loop.
    pub split_decoder_heads: bool,
    /// Optional allow-list of module names handled by Muon; when set it must
    /// be non-empty (`validate`). `None` presumably means all eligible
    /// modules — TODO confirm.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub target_modules: Option<Vec<String>>,
}

impl Default for MuonHybridConfig {
    fn default() -> Self {
        Self {
            enabled: false,
            momentum: 0.95,
            nesterov: true,
            ns_steps: 5,
            adjust_lr_fn: MuonAdjustLrFn::default(),
            split_decoder_heads: true,
            target_modules: None,
        }
    }
}

/// Top-level optimizer settings, typically deserialized from a training
/// configuration file (field names follow serde's snake_case defaults).
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
pub struct OptimizerConfig {
    /// Which optimizer to use; defaults to `OptimizerKind::Adamw`.
    #[serde(default)]
    pub name: OptimizerKind,
    /// Base learning rate; `validate` requires > 0.
    pub learning_rate: f64,
    /// Weight decay coefficient; `validate` requires >= 0.
    pub weight_decay: f32,
    /// Optional second weight-decay value; must be >= 0 when set. Presumably
    /// the end value of a decay schedule — TODO confirm against the trainer.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub weight_decay_final: Option<f32>,
    /// Optional learning-rate schedule; `None` presumably means a constant
    /// rate — confirm against the training loop.
    #[serde(default)]
    pub lr_schedule: Option<LearningRateScheduleConfig>,
    /// Which reference schedule behavior to follow.
    #[serde(default)]
    pub schedule_mode: OptimizerScheduleMode,
    /// Gradient clipping threshold (by norm, per the field name); must be > 0
    /// and is mutually exclusive with `grad_clip_value` (see `validate`).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub grad_clip_norm: Option<f32>,
    /// Gradient clipping threshold (by value, per the field name); must be > 0
    /// and is mutually exclusive with `grad_clip_norm` (see `validate`).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub grad_clip_value: Option<f32>,
    /// Muon settings; required and enabled when
    /// `name == OptimizerKind::MuonHybridExp` (see `validate`).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub muon: Option<MuonHybridConfig>,
}

impl OptimizerConfig {
    /// Checks the configuration for internal consistency and returns the
    /// first violation found.
    ///
    /// Rules enforced:
    /// - `learning_rate` is finite and > 0;
    /// - `weight_decay` (and `weight_decay_final`, when set) is finite and >= 0;
    /// - `grad_clip_norm` / `grad_clip_value` are finite and > 0 when set, and
    ///   may not both be set;
    /// - when `name` is `MuonHybridExp`: `muon` is present with
    ///   `enabled = true`, `momentum` in the open interval (0, 1),
    ///   `ns_steps > 0`, and `target_modules` non-empty when set.
    ///
    /// Non-finite floats (NaN, ±inf) are rejected explicitly: NaN compares
    /// false against every threshold, so a naive `x <= 0.0` check would let a
    /// NaN learning rate or momentum slip through validation.
    ///
    /// # Errors
    ///
    /// Returns an `anyhow` error whose message names the offending field.
    pub fn validate(&self) -> Result<()> {
        if !self.learning_rate.is_finite() || self.learning_rate <= 0.0 {
            return Err(anyhow!(
                "optimizer.learning_rate must be a finite value > 0"
            ));
        }
        if !self.weight_decay.is_finite() || self.weight_decay < 0.0 {
            return Err(anyhow!(
                "optimizer.weight_decay must be a finite value >= 0"
            ));
        }
        if let Some(weight_decay_final) = self.weight_decay_final
            && (!weight_decay_final.is_finite() || weight_decay_final < 0.0)
        {
            return Err(anyhow!(
                "optimizer.weight_decay_final must be a finite value >= 0"
            ));
        }
        if let Some(clip) = self.grad_clip_norm
            && (!clip.is_finite() || clip <= 0.0)
        {
            return Err(anyhow!(
                "optimizer.grad_clip_norm must be a finite value > 0"
            ));
        }
        if let Some(clip) = self.grad_clip_value
            && (!clip.is_finite() || clip <= 0.0)
        {
            return Err(anyhow!(
                "optimizer.grad_clip_value must be a finite value > 0"
            ));
        }
        if self.grad_clip_norm.is_some() && self.grad_clip_value.is_some() {
            return Err(anyhow!(
                "optimizer.grad_clip_norm and optimizer.grad_clip_value are mutually exclusive"
            ));
        }
        if matches!(self.name, OptimizerKind::MuonHybridExp) {
            let muon = self.muon.as_ref().ok_or_else(|| {
                anyhow!("optimizer.muon must be set when optimizer.name = \"muon_hybrid_exp\"")
            })?;
            if !muon.enabled {
                return Err(anyhow!(
                    "optimizer.muon.enabled must be true when optimizer.name = \"muon_hybrid_exp\""
                ));
            }
            // Negated conjunction (rather than `<= 0.0 || >= 1.0`) so a NaN
            // momentum also fails the check.
            if !(muon.momentum > 0.0 && muon.momentum < 1.0) {
                return Err(anyhow!("optimizer.muon.momentum must be in (0, 1)"));
            }
            if muon.ns_steps == 0 {
                return Err(anyhow!("optimizer.muon.ns_steps must be > 0"));
            }
            if let Some(target_modules) = muon.target_modules.as_ref()
                && target_modules.is_empty()
            {
                return Err(anyhow!(
                    "optimizer.muon.target_modules must be non-empty when set"
                ));
            }
        }
        Ok(())
    }
}

/// Learning-rate schedule selection, deserialized from an internally tagged
/// representation (`type = "constant" | "cosine" | "linear" | "exponential" |
/// "step" | "noam"`).
///
/// Optional fields default to `None`; presumably the training loop fills them
/// in from other config values (e.g. the optimizer's base learning rate) —
/// TODO confirm against the schedule constructor.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum LearningRateScheduleConfig {
    /// Fixed learning rate.
    Constant {
        #[serde(default)]
        initial_lr: Option<f64>,
    },
    /// Cosine-style schedule (per the variant name) with optional warmup.
    Cosine {
        #[serde(default)]
        initial_lr: Option<f64>,
        #[serde(default)]
        min_lr: Option<f64>,
        #[serde(default)]
        warmup_steps: Option<usize>,
        #[serde(default)]
        num_iters: Option<usize>,
    },
    /// Linear interpolation toward `final_lr` (the only mandatory field).
    Linear {
        #[serde(default)]
        initial_lr: Option<f64>,
        final_lr: f64,
        #[serde(default)]
        num_iters: Option<usize>,
    },
    /// Exponential decay with multiplicative factor `gamma`.
    Exponential {
        #[serde(default)]
        initial_lr: Option<f64>,
        gamma: f64,
    },
    /// Step decay: `gamma` defaults to 0.1 via `default_step_gamma`.
    Step {
        #[serde(default)]
        initial_lr: Option<f64>,
        #[serde(default = "default_step_gamma")]
        gamma: f64,
        #[serde(default)]
        step_size: Option<usize>,
    },
    /// Noam-style schedule (per the variant name) — presumably the
    /// warmup-then-inverse-sqrt schedule; confirm against the implementation.
    Noam {
        #[serde(default)]
        initial_lr: Option<f64>,
        #[serde(default)]
        warmup_steps: Option<usize>,
        #[serde(default)]
        model_size: Option<usize>,
    },
}

/// Serde default for `LearningRateScheduleConfig::Step::gamma`.
fn default_step_gamma() -> f64 {
    const STEP_GAMMA: f64 = 0.1;
    STEP_GAMMA
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal valid configuration that the tests below override per-field.
    fn base_optimizer() -> OptimizerConfig {
        OptimizerConfig {
            learning_rate: 1.0e-3,
            weight_decay: 0.0,
            name: OptimizerKind::default(),
            schedule_mode: OptimizerScheduleMode::default(),
            weight_decay_final: None,
            lr_schedule: None,
            grad_clip_norm: None,
            grad_clip_value: None,
            muon: None,
        }
    }

    #[test]
    fn muon_hybrid_requires_enabled_muon_config() {
        let muon = MuonHybridConfig {
            enabled: false,
            ..MuonHybridConfig::default()
        };
        let cfg = OptimizerConfig {
            name: OptimizerKind::MuonHybridExp,
            muon: Some(muon),
            ..base_optimizer()
        };
        let err = cfg.validate().expect_err("expected validation failure");
        let msg = err.to_string();
        assert!(
            msg.contains("optimizer.muon.enabled"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn weight_decay_final_must_be_non_negative() {
        let mut cfg = base_optimizer();
        cfg.weight_decay_final = Some(-0.1);
        let err = cfg.validate().expect_err("expected validation failure");
        let msg = err.to_string();
        assert!(
            msg.contains("weight_decay_final"),
            "unexpected error: {err}"
        );
    }
}