use std::fs;
use std::io::Write;
use std::path::{Path, PathBuf};
use std::time::Instant;
use rand::SeedableRng;
use rand::seq::SliceRandom;
use serde::Serialize;
use crate::checkpoint::{Checkpoint, ParameterSnapshot, save_checkpoint, snapshot_parameter};
use crate::dataset_bridge::{LocalDataset, LocalSample};
use crate::domain::DomainId;
use crate::metrics::grad_norm;
use crate::model::layer::RouterCache;
use crate::model::{Layer, Model, Parameter};
use crate::model_arch::{ModelArch, ModelKind, RouterKind, arch_fingerprint};
use crate::moe_model::diagnose_ref::{
build_from_model as build_fp32_ref, forward_backward as fp32_forward_backward, grad_norm_of,
param_names as fp32_param_names, simulated_fp16_forward_backward as fp16_sim_forward_backward,
};
use crate::moe_model::{MoEModel, MoESize, N_EXPERTS};
use crate::object::{Dim, Shape, Tensor};
use crate::synth_data::regression::RegressionSample;
use crate::synth_data::{QualitySample, make_quality_decision_dataset, make_regression_dataset};
use crate::{Error, Result};
/// One supervised example stored as flat `f32` vectors, independent of the dataset source.
///
/// When a batch is assembled by `stack_batch`, each `input` becomes one row of a
/// `[batch, in_dim]` tensor and each `target` one row of a `[batch, out_dim]` tensor;
/// every sample in a batch must therefore have the same `input` and `target` lengths.
#[derive(Debug, Clone)]
struct TrainingSample {
    /// Feature vector fed to the model.
    input: Vec<f32>,
    /// Target vector compared against the model output with mean-squared error.
    target: Vec<f32>,
}
impl From<QualitySample> for TrainingSample {
fn from(s: QualitySample) -> Self {
Self {
input: s.input,
target: s.target,
}
}
}
impl From<RegressionSample> for TrainingSample {
fn from(s: RegressionSample) -> Self {
Self {
input: s.0,
target: s.1,
}
}
}
impl From<LocalSample> for TrainingSample {
fn from(s: LocalSample) -> Self {
Self {
input: s.features,
target: s.labels,
}
}
}
/// Parameter-update rule applied by `train_step_cpu`.
///
/// Defaults to [`Optimizer::Adamw`].
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize)]
pub enum Optimizer {
    /// AdamW update via `Parameter::adamw_step` (β1 = 0.9, β2 = 0.999, ε = 1e-8) using
    /// `TrainConfig::weight_decay`; may run on HIP when that path is enabled.
    #[default]
    Adamw,
    /// Momentum SGD via `Parameter::sgd_momentum_step` using `TrainConfig::momentum`.
    Sgd,
}
/// Full configuration for a training run (`run_training`) or a diagnostic run.
#[derive(Debug, Clone)]
pub struct TrainConfig {
    /// Requested model family (e.g. `ModelKind::MoE`, `ModelKind::Dense`).
    /// NOTE(review): `run_training` in this file always builds an `MoEModel` and never
    /// reads this field — confirm whether it is honored elsewhere.
    pub model_kind: ModelKind,
    /// Requested router variant (e.g. `RouterKind::SheafPadic`, `RouterKind::SoftmaxOnly`).
    /// NOTE(review): not read by the code visible in this file — confirm where it is used.
    pub router_kind: RouterKind,
    /// Model size preset, used for model construction and as the metrics label.
    pub size: MoESize,
    /// Number of optimizer steps to run; `run_training` rejects 0.
    pub steps: u32,
    /// Samples per step; `run_training` rejects 0.
    pub batch_size: usize,
    /// Peak learning rate (see `TrainConfig::lr_at` for the schedule).
    pub lr: f32,
    /// Weight decay passed to the AdamW update.
    pub weight_decay: f32,
    /// Gradient-norm clipping threshold over all parameter gradients; values <= 0 disable clipping.
    pub grad_clip: f32,
    /// Seed for model initialization, synthetic data generation and the shuffle RNG.
    pub seed: u32,
    /// Directory created by `run_training`; the checkpoint `checkpoint.tkp1` is written inside it.
    pub checkpoint_dir: PathBuf,
    /// Destination of the JSON-lines per-step metrics log.
    pub metrics_path: PathBuf,
    /// Destination of the architecture JSON (with fingerprint and parameter count).
    pub arch_path: PathBuf,
    /// Suppresses the per-step progress lines on stderr.
    pub quiet: bool,
    /// Dataset source: `synth:quality:N`, `synth:regression:N`, `empty`, or a directory
    /// containing `quality_decisions.db` and `quality_outcomes.db`.
    pub dataset_spec: String,
    /// Requests the batched HIP AdamW update; only honored when built with the
    /// `rocm-hip` feature, `optimizer` is `Adamw`, and a HIP runtime is detected.
    pub use_hip: bool,
    /// Skips model construction and emits synthetic metrics plus an empty checkpoint.
    pub dry_run: bool,
    /// Update rule used for every parameter.
    pub optimizer: Optimizer,
    /// Momentum coefficient for `Optimizer::Sgd`.
    pub momentum: f32,
    /// Presumably selects `run_diagnose` instead of training — not read in the code
    /// visible here; TODO confirm against the caller.
    pub diagnose: bool,
    /// Multiplier applied to the learning rate of router parameters.
    pub router_lr_scale: f32,
    /// Number of linear-warmup steps at the start of the schedule.
    pub warmup_steps: u32,
    /// Final learning rate of the cosine decay, as a fraction of `lr`.
    pub min_lr_ratio: f32,
}
impl TrainConfig {
    /// Learning rate for the 0-based optimizer step `step`: linear warmup, then cosine decay.
    ///
    /// * `step < warmup_steps`: `lr * step / warmup_steps`. Step 0 therefore uses a
    ///   learning rate of exactly 0 whenever `warmup_steps > 0`.
    /// * `steps <= warmup_steps`: there is no decay phase and `lr` is returned.
    /// * otherwise: cosine decay from `lr` toward `lr * min_lr_ratio` across the
    ///   `steps - warmup_steps` post-warmup steps, with progress clamped to `[0, 1]`.
    pub fn lr_at(&self, step: u32) -> f32 {
        let base = self.lr;
        if step < self.warmup_steps {
            // `warmup_steps > step >= 0` here, so the division can never be by zero.
            return base * (step as f32 / self.warmup_steps as f32);
        }
        if self.steps <= self.warmup_steps {
            return base;
        }
        let min_lr = base * self.min_lr_ratio;
        let decay_span = (self.steps - self.warmup_steps) as f32;
        let progress = ((step - self.warmup_steps) as f32 / decay_span).clamp(0.0, 1.0);
        let cosine = 0.5 * (1.0 + (std::f32::consts::PI * progress).cos());
        min_lr + (base - min_lr) * cosine
    }
}
impl TrainConfig {
    /// The default preset; identical to [`TrainConfig::tiny_moe_sheaf_padic`].
    pub fn tiny_default() -> Self {
        Self::tiny_moe_sheaf_padic()
    }

    /// Tiny MoE preset with the sheaf/p-adic router.
    ///
    /// Outputs are placed under `./var/training/<unix-seconds>/` (metrics in
    /// `metrics.jsonl`, architecture in `arch.json`).
    pub fn tiny_moe_sheaf_padic() -> Self {
        let checkpoint_dir = PathBuf::from(format!("./var/training/{}", unix_timestamp()));
        let metrics_path = checkpoint_dir.join("metrics.jsonl");
        let arch_path = checkpoint_dir.join("arch.json");
        Self {
            model_kind: ModelKind::MoE,
            router_kind: RouterKind::SheafPadic,
            size: MoESize::Tiny,
            steps: 50,
            batch_size: 128,
            lr: 0.01,
            weight_decay: 0.01,
            grad_clip: 0.5,
            seed: 42,
            checkpoint_dir,
            metrics_path,
            arch_path,
            quiet: false,
            dataset_spec: "synth:quality:512".to_string(),
            use_hip: false,
            dry_run: false,
            optimizer: Optimizer::Adamw,
            momentum: 0.9,
            diagnose: false,
            router_lr_scale: 0.1,
            warmup_steps: 5,
            min_lr_ratio: 0.1,
        }
    }

    /// Same as [`TrainConfig::tiny_moe_sheaf_padic`] but with the softmax-only router.
    pub fn tiny_softmax_moe() -> Self {
        Self {
            router_kind: RouterKind::SoftmaxOnly,
            ..Self::tiny_moe_sheaf_padic()
        }
    }

    /// Dense-model variant of the tiny preset; the router learning rate is not scaled down.
    pub fn tiny_dense() -> Self {
        Self {
            model_kind: ModelKind::Dense,
            router_lr_scale: 1.0,
            ..Self::tiny_moe_sheaf_padic()
        }
    }
}
/// One row of the per-step metrics log, written by `run_training` as a single JSON
/// line to `TrainConfig::metrics_path`. Field order is the JSON key order.
#[derive(Debug, Clone, Serialize)]
pub struct StepMetrics {
    /// 0-based index of the training step.
    pub step: u32,
    /// Mean-squared-error loss on this step's batch (synthetic `1 / (1 + step)` in dry runs).
    pub loss: f32,
    /// Currently always written with the same value as `loss` by `run_training`.
    pub final_loss: f32,
    /// Gradient norm measured before clipping (synthetic in dry runs).
    pub grad_norm: f32,
    /// Learning rate used for this step, from `TrainConfig::lr_at`.
    pub lr: f32,
    /// Wall-clock milliseconds since the training loop started.
    pub elapsed_ms: f64,
    /// `MoESize::name()` of the configured model size.
    pub model_size: String,
    /// Mean per-row entropy of the router weights (`ln(N_EXPERTS)` in dry runs).
    pub router_entropy_mean: f32,
    /// Fraction of batch rows whose largest router weight selects each expert
    /// (uniform `1 / N_EXPERTS` in dry runs).
    pub per_expert_share: [f32; N_EXPERTS],
}
/// Result of a completed `run_training` call.
#[derive(Debug, Clone)]
pub struct TrainSummary {
    /// `MoESize::name()` of the configured model size.
    pub model_size: String,
    /// Scalar parameter count of the built model (the size preset's estimate in dry runs).
    pub total_params: usize,
    /// Loss of the last step that ran.
    pub final_loss: f32,
    /// Number of steps run (equal to `TrainConfig::steps`).
    pub steps_run: u32,
    /// Seconds from the start of the training loop until after the checkpoint was saved.
    pub time_elapsed_sec: f64,
    /// `steps_run / time_elapsed_sec`, or 0.0 when no time elapsed.
    pub throughput_steps_per_sec: f64,
    /// Path of the saved checkpoint (`<checkpoint_dir>/checkpoint.tkp1`).
    pub checkpoint_path: PathBuf,
    /// Path of the architecture JSON that was written.
    pub arch_path: PathBuf,
    /// Fingerprint computed from the model architecture by `arch_fingerprint`.
    pub arch_fingerprint: String,
    /// Router entropy of the last step; `NaN` for dry runs.
    pub last_router_entropy_mean: f32,
}
/// Default number of training steps for a model size: 20 for `Nano`, 50 for `Tiny`,
/// and 100 for every larger preset.
pub fn default_steps(size: MoESize) -> u32 {
    if matches!(size, MoESize::Nano) {
        20
    } else if matches!(size, MoESize::Tiny) {
        50
    } else {
        100
    }
}
/// Whether the HIP optimizer path is requested by default. It is opt-in, so this is `false`.
pub fn default_use_hip() -> bool {
    const HIP_ENABLED_BY_DEFAULT: bool = false;
    HIP_ENABLED_BY_DEFAULT
}
pub fn run_training(cfg: &TrainConfig) -> Result<TrainSummary> {
if cfg.steps == 0 {
return Err(Error::backend("run_training: steps must be >= 1"));
}
if cfg.batch_size == 0 {
return Err(Error::backend("run_training: batch_size must be >= 1"));
}
fs::create_dir_all(&cfg.checkpoint_dir).map_err(|e| {
Error::backend(format!(
"create_dir_all {}: {e}",
cfg.checkpoint_dir.display()
))
})?;
let (total_params, arch, model_opt) = if cfg.dry_run {
let total = cfg.size.param_count();
let arch = ModelArch::from_size(cfg.size, cfg.seed as u64);
(total, arch, None)
} else {
let model = MoEModel::new(cfg.size, cfg.seed as u64);
let total = model.scalar_param_count();
let arch = ModelArch::from_moe_model(&model);
(total, arch, Some(model))
};
let mut dataset = build_dataset(&cfg.dataset_spec, cfg.seed)?;
if dataset.is_empty() {
return Err(Error::backend("run_training: dataset is empty"));
}
let fingerprint = arch_fingerprint(&arch);
let param_count_value = total_params as u64;
write_arch_with_fingerprint(&cfg.arch_path, &arch, &fingerprint, param_count_value)?;
let mut metrics_file = fs::File::create(&cfg.metrics_path).map_err(|e| {
Error::backend(format!(
"create metrics file {}: {e}",
cfg.metrics_path.display()
))
})?;
let mut rng = rand::rngs::StdRng::seed_from_u64(cfg.seed as u64);
let model_size_label = cfg.size.name().to_string();
let started = Instant::now();
let mut last_loss = 0.0f32;
let mut last_router_entropy = f32::NAN;
let checkpoint_path = cfg.checkpoint_dir.join("checkpoint.tkp1");
if cfg.dry_run {
let dry_share = [1.0 / N_EXPERTS as f32; N_EXPERTS];
let dry_entropy = (N_EXPERTS as f32).ln();
for step in 0..cfg.steps {
let sp = step as f32;
let loss = 1.0f32 / (1.0f32 + sp);
let grad_n = 1.0f32 / (1.0f32 + sp);
let elapsed_ms = started.elapsed().as_secs_f64() * 1000.0;
last_loss = loss;
let row = StepMetrics {
step,
loss,
final_loss: loss,
grad_norm: grad_n,
lr: cfg.lr_at(step),
elapsed_ms,
model_size: model_size_label.clone(),
router_entropy_mean: dry_entropy,
per_expert_share: dry_share,
};
let line = serde_json::to_string(&row)
.map_err(|e| Error::backend(format!("serialize metrics row: {e}")))?;
writeln!(metrics_file, "{line}")
.map_err(|e| Error::backend(format!("write metrics row: {e}")))?;
metrics_file
.flush()
.map_err(|e| Error::backend(format!("flush metrics row: {e}")))?;
if !cfg.quiet {
eprintln!(
"[train_quality_moe] step {:4} loss={:.6} grad_norm={:.4} router_entropy={:.4} elapsed_ms={:.1}",
step, loss, grad_n, dry_entropy, elapsed_ms
);
}
}
let optimizer_name = match cfg.optimizer {
Optimizer::Adamw => "adamw",
Optimizer::Sgd => "sgd",
};
let ckpt = Checkpoint {
step: cfg.steps,
params: Vec::new(),
config: format!(
"{{\"size\":\"{}\",\"seed\":{},\"steps\":{},\"dry_run\":true,\"optimizer\":\"{}\",\"momentum\":{}}}",
cfg.size.name(),
cfg.seed,
cfg.steps,
optimizer_name,
cfg.momentum,
),
};
save_checkpoint(&checkpoint_path, &ckpt)?;
} else {
let mut model = model_opt
.expect("run_training: model_opt must be Some when dry_run is false (checked above)");
for step in 0..cfg.steps {
if step as usize % dataset.len().max(1) == 0 && step > 0 {
dataset.shuffle(&mut rng);
}
let batch = dataset.batch(cfg.batch_size, step);
if batch.is_empty() {
return Err(Error::backend(format!(
"run_training: empty batch at step {step}"
)));
}
let (xb, yb) = stack_batch(&batch);
let step_started = Instant::now();
let (loss, grad_n, router_entropy, per_expert_share) =
train_step_cpu(&mut model, &xb, &yb, cfg, cfg.lr_at(step))?;
let step_ms = step_started.elapsed().as_secs_f64() * 1000.0;
let elapsed_ms = started.elapsed().as_secs_f64() * 1000.0;
last_loss = loss;
last_router_entropy = router_entropy;
let row = StepMetrics {
step,
loss,
final_loss: loss,
grad_norm: grad_n,
lr: cfg.lr_at(step),
elapsed_ms,
model_size: model_size_label.clone(),
router_entropy_mean: router_entropy,
per_expert_share,
};
let line = serde_json::to_string(&row)
.map_err(|e| Error::backend(format!("serialize metrics row: {e}")))?;
writeln!(metrics_file, "{line}")
.map_err(|e| Error::backend(format!("write metrics row: {e}")))?;
metrics_file
.flush()
.map_err(|e| Error::backend(format!("flush metrics row: {e}")))?;
if !cfg.quiet {
eprintln!(
"[train_quality_moe] step {:4} loss={:.6} grad_norm={:.4} router_entropy={:.4} step_ms={:.2} elapsed_ms={:.1}",
step, loss, grad_n, router_entropy, step_ms, elapsed_ms
);
}
}
let ckpt = build_checkpoint(&model, cfg.steps, cfg.seed as u64);
save_checkpoint(&checkpoint_path, &ckpt)?;
}
let total_elapsed = started.elapsed().as_secs_f64();
let throughput = if total_elapsed > 0.0 {
cfg.steps as f64 / total_elapsed
} else {
0.0
};
Ok(TrainSummary {
model_size: model_size_label,
total_params,
final_loss: last_loss,
steps_run: cfg.steps,
time_elapsed_sec: total_elapsed,
throughput_steps_per_sec: throughput,
checkpoint_path: cfg.checkpoint_dir.join("checkpoint.tkp1"),
arch_path: cfg.arch_path.clone(),
arch_fingerprint: fingerprint,
last_router_entropy_mean: last_router_entropy,
})
}
/// Builds the in-memory dataset described by `spec`.
///
/// Accepted forms:
/// * `synth:quality:N` — `N` synthetic quality-decision samples seeded by `seed`;
/// * `synth:regression:N` — `N` synthetic regression samples with the quality
///   input/output dimensions, seeded by `seed`;
/// * `empty` — no samples;
/// * anything else — a directory containing `quality_decisions.db` and
///   `quality_outcomes.db`, loaded through `LocalDataset::open_sqlite`.
///
/// # Errors
/// Fails on an unparsable `N`, an unknown `synth:` sub-kind, or when the SQLite pair
/// cannot be opened.
fn build_dataset(spec: &str, seed: u32) -> Result<Dataset> {
    if spec == "empty" {
        return Ok(Dataset {
            samples: Vec::new(),
        });
    }
    if let Some(kind) = spec.strip_prefix("synth:") {
        let samples: Vec<TrainingSample> = if let Some(count) = kind.strip_prefix("quality:") {
            let n: usize = count
                .parse()
                .map_err(|e| Error::backend(format!("parse synth:quality:N N: {e}")))?;
            make_quality_decision_dataset(n, seed as u64)
                .into_iter()
                .map(TrainingSample::from)
                .collect()
        } else if let Some(count) = kind.strip_prefix("regression:") {
            let n: usize = count
                .parse()
                .map_err(|e| Error::backend(format!("parse synth:regression:N N: {e}")))?;
            make_regression_dataset(
                n,
                crate::synth_data::QUALITY_INPUT_DIM,
                crate::synth_data::QUALITY_OUTPUT_DIM,
                seed as u64,
            )
            .into_iter()
            .map(TrainingSample::from)
            .collect()
        } else {
            return Err(Error::backend(format!(
                "unknown synth dataset spec: {spec} (expected synth:quality:N or synth:regression:N)"
            )));
        };
        return Ok(Dataset { samples });
    }
    let root = Path::new(spec);
    let local = LocalDataset::open_sqlite(
        &root.join("quality_decisions.db"),
        &root.join("quality_outcomes.db"),
    )
    .map_err(|e| {
        Error::backend(format!(
            "open sqlite dataset at {spec}: {e} (looking for quality_decisions.db and quality_outcomes.db)"
        ))
    })?;
    Ok(Dataset {
        samples: local
            .as_slice()
            .iter()
            .cloned()
            .map(TrainingSample::from)
            .collect(),
    })
}
/// In-memory list of training samples, consumed cyclically by `Dataset::batch`.
#[derive(Debug, Clone)]
struct Dataset {
    /// All samples, in their current (possibly shuffled) order.
    samples: Vec<TrainingSample>,
}
impl Dataset {
    /// Number of samples held.
    fn len(&self) -> usize {
        self.samples.len()
    }

    /// `true` when the dataset holds no samples.
    fn is_empty(&self) -> bool {
        self.samples.is_empty()
    }

    /// Randomly permutes the samples in place using `rng`.
    fn shuffle(&mut self, rng: &mut rand::rngs::StdRng) {
        self.samples.shuffle(rng);
    }

    /// Returns `batch_size` cloned samples for step `step`.
    ///
    /// The batch starts at offset `step * batch_size` (modulo the dataset length) and
    /// wraps around, so samples repeat within a batch when `batch_size` exceeds the
    /// dataset length. Returns an empty vector for an empty dataset.
    fn batch(&self, batch_size: usize, step: u32) -> Vec<TrainingSample> {
        let total = self.samples.len();
        if total == 0 {
            return Vec::new();
        }
        let offset = (step as usize * batch_size) % total;
        (0..batch_size)
            .map(|k| self.samples[(offset + k) % total].clone())
            .collect()
    }
}
/// Stacks a batch of samples into row-major `[B, in_dim]` input and `[B, out_dim]`
/// target tensors (domain `"f32"`, CPU dense storage).
///
/// The first sample fixes `in_dim` and `out_dim`.
///
/// # Panics
/// Panics if `batch` is empty or if any sample's input/target length differs from
/// the first sample's.
fn stack_batch(batch: &[TrainingSample]) -> (Tensor<f32>, Tensor<f32>) {
    let rows = batch.len();
    let (in_dim, out_dim) = (batch[0].input.len(), batch[0].target.len());
    let mut inputs: Vec<f32> = Vec::with_capacity(rows * in_dim);
    let mut targets: Vec<f32> = Vec::with_capacity(rows * out_dim);
    for sample in batch {
        assert_eq!(sample.input.len(), in_dim, "stack_batch: input dim mismatch");
        assert_eq!(sample.target.len(), out_dim, "stack_batch: target dim mismatch");
        inputs.extend_from_slice(&sample.input);
        targets.extend_from_slice(&sample.target);
    }
    let to_matrix = |cols: usize, values: Vec<f32>| {
        Tensor::dense_cpu(
            DomainId::new("f32"),
            Shape::new(vec![Dim::Static(rows), Dim::Static(cols)]),
            values,
        )
    };
    let input_tensor = to_matrix(in_dim, inputs);
    let target_tensor = to_matrix(out_dim, targets);
    (input_tensor, target_tensor)
}
/// Runs one training step on a single batch: forward, mean-squared-error loss,
/// backward, gradient-norm clipping, and an optimizer update of every parameter.
///
/// Router parameters, which `all_parameters_mut` lists first, are updated with
/// `base_lr * cfg.router_lr_scale`; all other parameters use `base_lr`. When built with
/// the `rocm-hip` feature and `cfg.optimizer == Adamw`, `cfg.use_hip`, and a detected
/// HIP runtime all hold, the AdamW update runs through the batched HIP kernel instead
/// of the per-parameter CPU path.
///
/// Returns `(loss, grad_norm_before_clipping, router_entropy_mean, per_expert_share)`.
///
/// # Errors
/// Returns a shape error if the logits are not at least rank 2 with static batch and
/// output dims, or if the logits/targets element counts differ from `batch * out_dim`.
/// Also propagates errors from the model's forward/backward passes and from the
/// optimizer update.
pub fn train_step_cpu(
    model: &mut MoEModel,
    inputs: &Tensor<f32>,
    targets: &Tensor<f32>,
    cfg: &TrainConfig,
    base_lr: f32,
) -> Result<(f32, f32, f32, [f32; N_EXPERTS])> {
    let output = model.forward(inputs)?;
    let (router_entropy_mean, per_expert_share) = compute_router_stats(&output.router_weights);
    let logits = &output.logits;
    let dims = &logits.meta.shape.dims;
    if dims.len() < 2 {
        return Err(Error::shape(format!(
            "train_step_cpu: logits must be at least rank 2 [batch, out_dim], got rank {}",
            dims.len()
        )));
    }
    let b = match &dims[0] {
        Dim::Static(v) => *v,
        _ => {
            return Err(Error::shape(
                "train_step_cpu: logits batch dim must be static",
            ));
        }
    };
    let out_dim = match &dims[1] {
        Dim::Static(v) => *v,
        _ => {
            return Err(Error::shape(
                "train_step_cpu: logits out dim must be static",
            ));
        }
    };
    if logits.data.len() != b * out_dim {
        return Err(Error::shape(format!(
            "train_step_cpu: logits data {} != batch*out_dim = {}*{}",
            logits.data.len(),
            b,
            out_dim
        )));
    }
    if targets.data.len() != b * out_dim {
        return Err(Error::shape(format!(
            "train_step_cpu: target data {} != batch*out_dim = {}*{}",
            targets.data.len(),
            b,
            out_dim
        )));
    }
    // Mean-squared error over every logit; d(loss)/d(logit_i) = 2 * diff_i / n.
    let n = b * out_dim;
    let nf = n as f32;
    let mut loss = 0.0f32;
    let mut grad_logits = vec![0.0f32; n];
    for i in 0..n {
        let diff = logits.data[i] - targets.data[i];
        loss += diff * diff;
        grad_logits[i] = 2.0 * diff / nf;
    }
    loss /= nf;
    let grad_output = Tensor::dense_cpu(
        logits.meta.domain.clone(),
        logits.meta.shape.clone(),
        grad_logits,
    );
    let (_grad_input, mut param_grads) = model.backward(&grad_output)?;
    let total_norm = grad_norm(&flatten_param_grads(&param_grads));
    let clip_scale = if cfg.grad_clip > 0.0 && total_norm > cfg.grad_clip {
        cfg.grad_clip / total_norm
    } else {
        1.0
    };
    apply_clip_scale(&mut param_grads, clip_scale);
    // Router parameters come first in `all_parameters_mut`. Count them from the router
    // itself instead of assuming exactly two (weight + bias), so routers with extra
    // parameters still get the router learning-rate scale.
    let router_param_count = model.router.parameters().len();
    let mut params = all_parameters_mut(model);
    if params.len() != param_grads.len() {
        return Err(Error::backend(format!(
            "train_step_cpu: param count {} != param_grads count {}",
            params.len(),
            param_grads.len()
        )));
    }
    #[cfg(feature = "rocm-hip")]
    {
        if cfg.optimizer == Optimizer::Adamw
            && cfg.use_hip
            && hip_adamw_available()
            && !params.is_empty()
        {
            for param in params.iter_mut() {
                if param.step == u32::MAX {
                    return Err(Error::backend(
                        "train_step_cpu: parameter step counter overflow",
                    ));
                }
                param.step += 1;
            }
            // The kernel takes an i32 step; refuse rather than silently wrap.
            let t_global = i32::try_from(params[0].step).map_err(|_| {
                Error::backend("train_step_cpu: parameter step counter exceeds i32::MAX")
            })?;
            let router_n = router_param_count.min(params.len());
            let (router_params, expert_params) = params.split_at_mut(router_n);
            let (router_grads, expert_grads) = param_grads.split_at(router_n);
            let router_grad_refs: Vec<&Tensor<f32>> = router_grads.iter().collect();
            let expert_grad_refs: Vec<&Tensor<f32>> = expert_grads.iter().collect();
            if !router_params.is_empty() {
                adamw_step_batched_binary_for_group(
                    router_params,
                    &router_grad_refs,
                    base_lr * cfg.router_lr_scale,
                    0.9,
                    0.999,
                    1e-8,
                    cfg.weight_decay,
                    t_global,
                )?;
            }
            if !expert_params.is_empty() {
                adamw_step_batched_binary_for_group(
                    expert_params,
                    &expert_grad_refs,
                    base_lr,
                    0.9,
                    0.999,
                    1e-8,
                    cfg.weight_decay,
                    t_global,
                )?;
            }
            return Ok((loss, total_norm, router_entropy_mean, per_expert_share));
        }
    }
    for (idx, (param, grad_t)) in params.iter_mut().zip(param_grads.iter()).enumerate() {
        let eff_lr = if idx < router_param_count {
            base_lr * cfg.router_lr_scale
        } else {
            base_lr
        };
        match cfg.optimizer {
            Optimizer::Sgd => {
                param.sgd_momentum_step(grad_t, eff_lr, cfg.momentum)?;
            }
            Optimizer::Adamw => {
                param.adamw_step(grad_t, eff_lr, 0.9, 0.999, 1e-8, cfg.weight_decay)?;
            }
        }
    }
    Ok((loss, total_norm, router_entropy_mean, per_expert_share))
}
/// Applies one AdamW step to a group of parameters with the batched ROCm/HIP kernel.
///
/// Each parameter's values and gradient are converted to f16 bit patterns with
/// `f32_to_f16`; the moment buffers `m`/`v` are passed as `f32`. After the kernel
/// returns, the values are converted back with `f16_to_f32` and the moments are copied
/// back into each parameter. Zero-length parameters are sent as empty slices and left
/// untouched.
///
/// NOTE(review): the parameter values round-trip through f16 on every call, so no
/// full-precision master copy survives between steps. Updates smaller than the f16
/// spacing at a weight's magnitude are lost. Confirm this is intended.
///
/// # Errors
/// Returns an error if `params` and `grads` differ in count, if any parameter's
/// value/`m`/`v`/gradient lengths disagree (including a zero-length parameter paired
/// with a non-empty gradient), or if the HIP kernel reports a failure.
#[cfg(feature = "rocm-hip")]
#[allow(clippy::too_many_arguments)]
fn adamw_step_batched_binary_for_group(
    params: &mut [&mut crate::model::parameter::Parameter],
    grads: &[&Tensor<f32>],
    lr: f32,
    beta1: f32,
    beta2: f32,
    eps: f32,
    weight_decay: f32,
    t: i32,
) -> Result<()> {
    use crate::backend::f16_convert::{f16_to_f32, f32_to_f16};
    #[allow(unused_imports)]
    mod _hip_adamw {
        pub use crate as tokitai_operator;
        include!("backend/hip_adamw.rs");
    }
    use _hip_adamw::run_rocm_hip_adamw_step_all_binary;
    let n_params = params.len();
    if n_params == 0 {
        return Ok(());
    }
    if n_params != grads.len() {
        return Err(Error::backend(format!(
            "adamw_step_batched_binary_for_group: params/grads count mismatch \
             (params={n_params}, grads={})",
            grads.len()
        )));
    }
    let mut theta_slices: Vec<Vec<u16>> = Vec::with_capacity(n_params);
    let mut m_slices: Vec<Vec<f32>> = Vec::with_capacity(n_params);
    let mut v_slices: Vec<Vec<f32>> = Vec::with_capacity(n_params);
    let mut grad_slices: Vec<Vec<u16>> = Vec::with_capacity(n_params);
    for (param, grad) in params.iter().zip(grads.iter()) {
        let n = param.data.data.len();
        // Validate every parameter, including empty ones, so a layout bug cannot pair
        // a zero-length parameter with a non-empty gradient unnoticed.
        if n != grad.data.len() || n != param.m.data.len() || n != param.v.data.len() {
            return Err(Error::backend(format!(
                "adamw_step_batched_binary_for_group: length mismatch \
                 (data={n}, m={}, v={}, grad={})",
                param.m.data.len(),
                param.v.data.len(),
                grad.data.len()
            )));
        }
        theta_slices.push(param.data.data.iter().map(|&x| f32_to_f16(x)).collect());
        m_slices.push(param.m.data.to_vec());
        v_slices.push(param.v.data.to_vec());
        grad_slices.push(grad.data.iter().map(|&x| f32_to_f16(x)).collect());
    }
    run_rocm_hip_adamw_step_all_binary(
        &mut theta_slices,
        &mut m_slices,
        &mut v_slices,
        &grad_slices,
        lr,
        beta1,
        beta2,
        eps,
        weight_decay,
        t,
    )?;
    for (param, ((theta, m), v)) in params
        .iter_mut()
        .zip(theta_slices.iter().zip(m_slices.iter()).zip(v_slices.iter()))
    {
        let n = param.data.data.len();
        if n == 0 {
            continue;
        }
        for (dst, &bits) in param.data.data.iter_mut().zip(theta[..n].iter()) {
            *dst = f16_to_f32(bits);
        }
        param.m.data.copy_from_slice(&m[..n]);
        param.v.data.copy_from_slice(&v[..n]);
    }
    Ok(())
}
/// Reports whether a local ROCm/HIP runtime was detected, i.e. whether the batched
/// HIP AdamW path may be used.
///
/// NOTE(review): this runs detection on every call, and it is invoked once per
/// training step. Consider caching if detection is expensive.
#[cfg(feature = "rocm-hip")]
fn hip_adamw_available() -> bool {
    let probe = crate::backend::rocm::detect_local_rocm_hip();
    probe.available
}
/// Summarizes router weights for a batch: the mean per-row entropy and the per-expert
/// top-1 share.
///
/// `router_weights.data` is read row-major as `b` rows of `N_EXPERTS` values, where `b`
/// is the first dimension when it is static and 1 otherwise.
/// * Entropy of a row is `-Σ p·ln p` over its entries with `p > 0`.
/// * The share of expert `e` is the fraction of rows whose largest weight sits at `e`;
///   ties go to the lowest index.
///
/// Presumably each row is a probability distribution (router softmax output), in which
/// case the entropy is meaningful. TODO confirm.
///
/// If the tensor holds fewer than `b * N_EXPERTS` values, the uniform fallback
/// `(ln N_EXPERTS, [1 / N_EXPERTS; N_EXPERTS])` is returned. A zero-row batch yields
/// `(0.0, [0.0; N_EXPERTS])`.
fn compute_router_stats(router_weights: &Tensor<f32>) -> (f32, [f32; N_EXPERTS]) {
    let b = match router_weights.meta.shape.dims.first() {
        Some(Dim::Static(v)) => *v,
        _ => 1,
    };
    if router_weights.data.len() < b * N_EXPERTS {
        return ((N_EXPERTS as f32).ln(), [1.0 / N_EXPERTS as f32; N_EXPERTS]);
    }
    let mut counts = [0usize; N_EXPERTS];
    let mut entropy_sum = 0.0f32;
    for row in router_weights.data[..b * N_EXPERTS].chunks_exact(N_EXPERTS) {
        // Only a strictly greater value replaces the current top, so ties keep the
        // lowest expert index.
        let mut top_e = 0usize;
        let mut top_v = row[0];
        for (ei, &v) in row.iter().enumerate().skip(1) {
            if v > top_v {
                top_v = v;
                top_e = ei;
            }
        }
        counts[top_e] += 1;
        let mut h = 0.0f32;
        for &p in row {
            if p > 0.0 {
                h -= p * p.ln();
            }
        }
        entropy_sum += h;
    }
    let bf = b.max(1) as f32;
    let entropy_mean = entropy_sum / bf;
    // Works for any `N_EXPERTS`, unlike a hand-written fixed-length array literal.
    let per_expert_share = counts.map(|c| c as f32 / bf);
    (entropy_mean, per_expert_share)
}
/// Concatenates the data of every gradient tensor, in order, into one flat vector
/// (used to compute the global gradient norm).
///
/// The output is allocated once at its exact final size rather than grown
/// incrementally, which matters for models with many parameters.
fn flatten_param_grads(grads: &[Tensor<f32>]) -> Vec<f32> {
    let total: usize = grads.iter().map(|g| g.data.len()).sum();
    let mut out: Vec<f32> = Vec::with_capacity(total);
    for g in grads {
        out.extend_from_slice(&g.data);
    }
    out
}
/// Cosine similarity of two equal-length vectors, accumulated in `f64`.
///
/// Conventions for degenerate inputs:
/// * both vectors all-zero: `1.0` (they are considered identical);
/// * exactly one all-zero (or the norm product underflows to zero): `0.0`.
///
/// # Panics
/// Panics if `a` and `b` have different lengths.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(
        a.len(),
        b.len(),
        "cosine_similarity: length mismatch ({} vs {})",
        a.len(),
        b.len()
    );
    let (dot, norm_a_sq, norm_b_sq) = a.iter().zip(b.iter()).fold(
        (0.0f64, 0.0f64, 0.0f64),
        |(dot, sa, sb), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, sa + x * x, sb + y * y)
        },
    );
    if norm_a_sq == 0.0 && norm_b_sq == 0.0 {
        return 1.0;
    }
    let denom = (norm_a_sq * norm_b_sq).sqrt();
    if denom == 0.0 {
        0.0
    } else {
        (dot / denom) as f32
    }
}
/// Multiplies every element of every gradient tensor by `scale` in place.
///
/// A scale of exactly `1.0` is the identity, so that case skips the pass entirely.
fn apply_clip_scale(grads: &mut [Tensor<f32>], scale: f32) {
    if scale != 1.0 {
        grads
            .iter_mut()
            .flat_map(|g| g.data.iter_mut())
            .for_each(|v| *v *= scale);
    }
}
/// Collects mutable references to every trainable parameter: the router's parameters
/// first, then each expert's parameters in expert order.
///
/// This ordering is what the optimizer loop relies on to pair parameters with the
/// gradients returned by `MoEModel::backward` and to identify router parameters.
fn all_parameters_mut(model: &mut MoEModel) -> Vec<&mut Parameter> {
    let mut params: Vec<&mut Parameter> = Vec::new();
    params.extend(model.router.parameters_mut());
    params.extend(
        model
            .experts
            .iter_mut()
            .flat_map(|expert| expert.parameters_mut()),
    );
    debug_assert!(
        params.len() >= 2,
        "all_parameters_mut: expected at least router.weight + router.bias, got {}",
        params.len()
    );
    params
}
/// Snapshots every model parameter into a `Checkpoint` tagged with `step`.
///
/// Router parameters are named `router.param_<i>` and expert parameters
/// `expert_<e>.param_<i>`, in the same order `all_parameters_mut` yields them. The
/// config field is a small JSON object holding the size name, seed, and step.
fn build_checkpoint(model: &MoEModel, step: u32, seed: u64) -> Checkpoint {
    let config = format!(
        "{{\"size\":\"{}\",\"seed\":{seed},\"steps\":{step}}}",
        model.size.name()
    );
    let mut params: Vec<ParameterSnapshot> = Vec::new();
    let router = model.router.parameters();
    params.extend(
        router
            .iter()
            .enumerate()
            .map(|(idx, param)| snapshot_parameter(param, &format!("router.param_{idx}"))),
    );
    for (expert_idx, expert) in model.experts.iter().enumerate() {
        let expert_params = expert.parameters();
        params.extend(expert_params.iter().enumerate().map(|(idx, param)| {
            snapshot_parameter(param, &format!("expert_{expert_idx}.param_{idx}"))
        }));
    }
    Checkpoint {
        step,
        params,
        config,
    }
}
/// Current wall-clock time in whole seconds since the Unix epoch, or 0 if the system
/// clock reads earlier than the epoch.
fn unix_timestamp() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_secs(),
        Err(_) => 0,
    }
}
fn write_arch_with_fingerprint(
path: &Path,
arch: &ModelArch,
fingerprint: &str,
param_count: u64,
) -> Result<()> {
let mut v = serde_json::to_value(arch)
.map_err(|e| Error::backend(format!("arch: serialize to value: {e}")))?;
if let Some(obj) = v.as_object_mut() {
obj.insert(
"fingerprint".to_string(),
serde_json::Value::String(fingerprint.to_string()),
);
obj.insert(
"param_count".to_string(),
serde_json::Value::Number(param_count.into()),
);
}
let text = serde_json::to_string_pretty(&v)
.map_err(|e| Error::backend(format!("arch: re-serialize: {e}")))?;
std::fs::write(path, text)
.map_err(|e| Error::backend(format!("arch: write {}: {e}", path.display())))?;
Ok(())
}
/// Per-parameter comparison produced by `run_diagnose`, contrasting gradient norms
/// from the HIP fp16 path, a simulated fp16 reference, and an fp32 reference.
#[derive(Debug, Clone)]
pub struct DiagnoseRow {
    /// Parameter name from the fp32 reference's naming scheme.
    pub name: String,
    /// Gradient norm from the HIP fp16 forward/backward; `None` when that path failed.
    pub fp16: Option<f32>,
    /// Gradient norm from the simulated fp16 forward/backward.
    pub fp16_sim: f32,
    /// Gradient norm from the fp32 reference.
    pub fp32: f32,
    /// `fp16_sim / fp32` (1.0 when both are zero, +inf when only fp32 is zero).
    pub ratio: f32,
    /// `fp16 / fp32` with the same conventions; `None` without a HIP result.
    pub hip_ratio: Option<f32>,
    /// `|fp16 - fp32|` of the gradient norms; `None` without a HIP result.
    pub abs_diff: Option<f32>,
    /// `abs_diff / max(|fp32|, 1e-12)`; `None` without a HIP result.
    pub rel_diff: Option<f32>,
    /// Cosine similarity between the HIP fp16 and fp32 gradient vectors; `None` when
    /// unavailable or when the vectors are empty or differ in length.
    pub cosine_sim: Option<f32>,
}
pub fn run_diagnose(cfg: &TrainConfig) -> Result<()> {
if cfg.batch_size == 0 {
return Err(Error::backend("run_diagnose: batch_size must be >= 1"));
}
let batch_size = if cfg.batch_size % 16 == 0 {
cfg.batch_size
} else {
eprintln!(
"[--diagnose] note: bumping batch {} -> 16 (Linear fp16 GEMM requires multiple of 16)",
cfg.batch_size
);
16
};
let dataset = build_dataset(&cfg.dataset_spec, cfg.seed)?;
if dataset.is_empty() {
return Err(Error::backend("run_diagnose: dataset is empty"));
}
let batch = dataset.batch(batch_size, 0);
let (xb, yb) = stack_batch(&batch);
let b = match &xb.meta.shape.dims[0] {
Dim::Static(v) => *v,
_ => return Err(Error::shape("run_diagnose: xb batch dim must be static")),
};
let in_dim = crate::moe_model::topology::IN_DIM;
let out_dim = crate::moe_model::topology::OUT_DIM;
if xb.data.len() != b * in_dim {
return Err(Error::shape(format!(
"run_diagnose: xb.len() {} != B*IN_DIM={}*{}",
xb.data.len(),
b,
in_dim
)));
}
if yb.data.len() != b * out_dim {
return Err(Error::shape(format!(
"run_diagnose: yb.len() {} != B*OUT_DIM={}*{}",
yb.data.len(),
b,
out_dim
)));
}
eprintln!(
"[--diagnose] model_size={} batch={} total_params={}",
cfg.size.name(),
b,
cfg.size.param_count()
);
let model = MoEModel::new(cfg.size, cfg.seed as u64);
let names = fp32_param_names(&model);
let mut fp16_per_param: Option<Vec<f32>> = None;
let mut fp16_grads_per_param: Option<Vec<Vec<f32>>> = None;
let mut fp16_top_k_indices: Option<Vec<usize>> = None;
let mut fp16_total_opt: Option<f32> = None;
let mut fp16_loss: Option<f32> = None;
let mut logits_fp16: Option<Vec<f32>> = None;
let mut grad_output_data: Option<Vec<f32>> = None;
let mut grad_output: Option<Tensor<f32>> = None;
match model.forward(&xb) {
Ok(output) => {
let logits = output.logits.data.clone();
if let Some(cache) = model.last_cache.borrow().as_ref() {
if let Some(router_cache) = cache.router_cache.downcast_ref::<RouterCache>() {
fp16_top_k_indices = Some(router_cache.top_k_indices.clone());
}
}
let mut grad_logits = vec![0.0f32; b * out_dim];
let mut loss = 0.0f32;
let nf = (b * out_dim) as f32;
for i in 0..(b * out_dim) {
let diff = logits[i] - yb.data[i];
loss += diff * diff;
grad_logits[i] = 2.0 * diff / nf;
}
loss /= nf;
let go = Tensor::dense_cpu(
output.logits.meta.domain.clone(),
output.logits.meta.shape.clone(),
grad_logits,
);
match model.backward(&go) {
Ok((_gi, param_grads_fp16)) => {
if param_grads_fp16.len() != names.len() {
return Err(Error::backend(format!(
"run_diagnose: fp16 param_grads.len() {} != names.len() {} (model.layer layout changed?)",
param_grads_fp16.len(),
names.len()
)));
}
fp16_per_param = Some(
param_grads_fp16
.iter()
.map(|g| grad_norm(&g.data))
.collect(),
);
fp16_grads_per_param =
Some(param_grads_fp16.iter().map(|g| g.data.clone()).collect());
fp16_total_opt = Some(grad_norm(&flatten_param_grads(¶m_grads_fp16)));
fp16_loss = Some(loss);
logits_fp16 = Some(logits);
grad_output_data = Some(go.data.clone());
grad_output = Some(go);
}
Err(e) => {
eprintln!(
"[--diagnose] fp16-via-HIP backward failed: {e}; \
continuing with fp32 reference only"
);
}
}
}
Err(e) => {
eprintln!(
"[--diagnose] fp16-via-HIP forward failed: {e}; \
continuing with fp32 reference only"
);
}
}
if grad_output_data.is_none() {
let ref_only = build_fp32_ref(&model);
let ref_res = fp32_forward_backward(&ref_only, &xb.data, &vec![0.0f32; b * out_dim]);
let _ = ref_res;
grad_output_data = Some(vec![0.0f32; b * out_dim]);
}
let ref_model = build_fp32_ref(&model);
let grad_for_fp32 = grad_output_data
.as_ref()
.expect("grad_output_data populated by fallback above");
let ref_result = fp32_forward_backward(&ref_model, &xb.data, grad_for_fp32);
let fp32_per_param: Vec<f32> = ref_result
.param_grads
.iter()
.map(|g| grad_norm_of(g))
.collect();
let fp32_total = grad_norm_of(
&ref_result
.param_grads
.iter()
.flat_map(|g| g.iter().copied())
.collect::<Vec<_>>(),
);
let sim_result = fp16_sim_forward_backward(&ref_model, &xb.data, grad_for_fp32);
let fp16_sim_per_param: Vec<f32> = sim_result
.param_grads
.iter()
.map(|g| grad_norm_of(g))
.collect();
let fp16_sim_total = grad_norm_of(
&sim_result
.param_grads
.iter()
.flat_map(|g| g.iter().copied())
.collect::<Vec<_>>(),
);
let mut max_abs = 0.0f32;
if let (Some(logits_h), ref_logits) = (&logits_fp16, &ref_result.logits) {
for (a, bv) in logits_h.iter().zip(ref_logits.iter()) {
let d = (a - bv).abs();
if d > max_abs {
max_abs = d;
}
}
}
let fp16_total = fp16_total_opt.unwrap_or(0.0);
let loss = fp16_loss.unwrap_or(0.0);
eprintln!(
"[--diagnose] forward: fp16 vs fp32 logits max_abs_diff={:.6} fp16_loss={:.6} \
fp16_total_grad_norm={:.6} fp16_sim_total_grad_norm={:.6} fp32_total_grad_norm={:.6}",
max_abs, loss, fp16_total, fp16_sim_total, fp32_total
);
let mut rows: Vec<DiagnoseRow> = Vec::with_capacity(names.len());
let mut ratio_max_sim = 0.0f32;
let mut ratio_max_hip: Option<f32> = None;
for (i, name) in names.iter().enumerate() {
let fp16_opt = fp16_per_param.as_ref().map(|v| v[i]);
let fp16_sim = fp16_sim_per_param[i];
let fp32 = fp32_per_param[i];
let sim_ratio = if fp32 > 0.0 {
fp16_sim / fp32
} else if fp16_sim == 0.0 {
1.0
} else {
f32::INFINITY
};
let hip_ratio_opt = fp16_opt.map(|h| {
if fp32 > 0.0 {
h / fp32
} else if h == 0.0 {
1.0
} else {
f32::INFINITY
}
});
if sim_ratio.is_finite() && sim_ratio > ratio_max_sim {
ratio_max_sim = sim_ratio;
}
if let Some(hip_ratio) = hip_ratio_opt {
let update = match ratio_max_hip {
None => true,
Some(cur) if cur.is_finite() && hip_ratio.is_finite() => hip_ratio > cur,
Some(_) => true, };
if update {
ratio_max_hip = Some(hip_ratio);
}
}
let rel_eps = 1.0e-12f32;
let (abs_diff, rel_diff, cosine_sim) = match fp16_opt {
Some(h) => {
let abs_d = (h - fp32).abs();
let rel_d = abs_d / fp32.abs().max(rel_eps);
let cos = match fp16_grads_per_param.as_ref() {
Some(grads) => {
let fp32_grad = &ref_result.param_grads[i];
if grads[i].len() == fp32_grad.len() && !grads[i].is_empty() {
Some(cosine_similarity(&grads[i], fp32_grad))
} else {
None
}
}
None => None,
};
(Some(abs_d), Some(rel_d), cos)
}
None => (None, None, None),
};
rows.push(DiagnoseRow {
name: name.clone(),
fp16: fp16_opt,
fp16_sim,
fp32,
ratio: sim_ratio,
hip_ratio: hip_ratio_opt,
abs_diff,
rel_diff,
cosine_sim,
});
}
let _ = grad_output; let name_w = rows.iter().map(|r| r.name.len()).max().unwrap_or(0);
eprintln!("[--diagnose] per-param grad_norm:");
if fp16_per_param.is_some() {
eprintln!(
" {:<name_w$} {:>12} {:>12} {:>12} {:>10} {:>10} {:>11} {:>10} {:>10}",
"param", "fp16(HIP)", "fp16(sim)", "fp32", "HIP/fp32", "sim/fp32", "|Δ|", "rel", "cos"
);
} else {
eprintln!(
" {:<name_w$} {:>12} {:>12} {:>10}",
"param", "fp16(sim)", "fp32", "sim/fp32"
);
}
let compact_ratio = |v: f32| -> String {
if !v.is_finite() {
"inf".to_string()
} else if v >= 1.0e6 {
">1e6".to_string()
} else {
format!("{:.2}", v)
}
};
let compact_diff = |v: Option<f32>| -> String {
match v {
None => "-".to_string(),
Some(x) if !x.is_finite() => "inf".to_string(),
Some(x) if x >= 1.0e6 => ">1e6".to_string(),
Some(x) => format!("{:.3e}", x),
}
};
let compact_rel = |v: Option<f32>| -> String {
match v {
None => "-".to_string(),
Some(x) if !x.is_finite() => "inf".to_string(),
Some(x) if x >= 1.0e6 => ">1e6".to_string(),
Some(x) => format!("{:.3e}", x),
}
};
let compact_cos = |v: Option<f32>| -> String {
match v {
None => "-".to_string(),
Some(x) if !x.is_finite() => "nan".to_string(),
Some(x) => {
format!("{:.4}", x)
}
}
};
for r in &rows {
if fp16_per_param.is_some() {
let fp16_str = match r.fp16 {
Some(v) => format!("{:>12.6e}", v),
None => format!("{:>12}", "(HIP n/a)"),
};
let fp32_str = format!("{:>12.6e}", r.fp32);
let fp16_sim_str = format!("{:>12.6e}", r.fp16_sim);
let hip_str = match r.hip_ratio {
Some(v) => format!("{:>10}", compact_ratio(v)),
None => format!("{:>10}", "-"),
};
let sim_str = format!("{:>10}", compact_ratio(r.ratio));
eprintln!(
" {:<name_w$} {} {} {} {} {} {:>11} {:>10} {:>10}",
r.name,
fp16_str,
fp16_sim_str,
fp32_str,
hip_str,
sim_str,
compact_diff(r.abs_diff),
compact_rel(r.rel_diff),
compact_cos(r.cosine_sim),
);
} else {
let fp16_sim_str = format!("{:>12.6e}", r.fp16_sim);
let fp32_str = format!("{:>12.6e}", r.fp32);
let sim_str = format!("{:>10}", compact_ratio(r.ratio));
eprintln!(
" {:<name_w$} {} {} {}",
r.name, fp16_sim_str, fp32_str, sim_str,
);
}
}
let top_k = crate::moe_model::topology::TOP_K;
let fp32_top_k = &ref_result.top_k_indices;
let sim_top_k = &sim_result.top_k_indices;
let fp16_top_k = fp16_top_k_indices.as_deref();
let (hip_diverge, hip_diverge_n) = match fp16_top_k {
Some(h) if h.len() == fp32_top_k.len() && !fp32_top_k.is_empty() => {
let mut n = 0usize;
for (bi, slot) in fp32_top_k.chunks(top_k).enumerate() {
let h_row = &h[bi * top_k..(bi + 1) * top_k];
let mut a: Vec<usize> = h_row.to_vec();
let mut b: Vec<usize> = slot.to_vec();
a.sort_unstable();
b.sort_unstable();
if a != b {
n += 1;
}
}
let rows = fp32_top_k.len() / top_k;
(Some((n, rows)), n)
}
_ => (None, 0),
};
let (sim_diverge, sim_diverge_n) =
if sim_top_k.len() == fp32_top_k.len() && !fp32_top_k.is_empty() {
let mut n = 0usize;
for (bi, slot) in fp32_top_k.chunks(top_k).enumerate() {
let s_row = &sim_top_k[bi * top_k..(bi + 1) * top_k];
let mut a: Vec<usize> = s_row.to_vec();
let mut b: Vec<usize> = slot.to_vec();
a.sort_unstable();
b.sort_unstable();
if a != b {
n += 1;
}
}
let rows = fp32_top_k.len() / top_k;
(Some((n, rows)), n)
} else {
(None, 0)
};
eprintln!("[--diagnose] router top-K assignments (B={}):", b);
eprintln!(
" {:<5} {:<24} {:<24} {:<24}",
"row", "fp16(HIP)", "fp16(sim)", "fp32"
);
for bi in 0..b {
let fp16_str = match fp16_top_k {
Some(h) => {
let row: Vec<String> = h[bi * top_k..(bi + 1) * top_k]
.iter()
.map(|e| format!("expert_{e}"))
.collect();
format!("[{}]", row.join(", "))
}
None => "(HIP n/a)".to_string(),
};
let sim_row: Vec<String> = sim_top_k[bi * top_k..(bi + 1) * top_k]
.iter()
.map(|e| format!("expert_{e}"))
.collect();
let fp32_row: Vec<String> = fp32_top_k[bi * top_k..(bi + 1) * top_k]
.iter()
.map(|e| format!("expert_{e}"))
.collect();
let sim_str = format!("[{}]", sim_row.join(", "));
let fp32_str = format!("[{}]", fp32_row.join(", "));
eprintln!(
" {:<5} {:<24} {:<24} {:<24}",
bi, fp16_str, sim_str, fp32_str
);
}
match hip_diverge {
Some((n, rows)) => {
eprintln!(
"[--diagnose] routing divergence: HIP fp16 vs fp32 = {n}/{rows} rows differ in top-K"
);
if n > 0 {
eprintln!(
"[--diagnose] [ROUTING DIVERGENT] HIP fp16 softmax picked different experts than fp32 on {n} of {rows} batch rows"
);
}
}
None => {
eprintln!(
"[--diagnose] routing divergence: HIP fp16 vs fp32 = n/a (HIP forward failed)"
);
}
}
match sim_diverge {
Some((n, rows)) => {
eprintln!(
"[--diagnose] routing divergence: sim fp16 vs fp32 = {n}/{rows} rows differ in top-K"
);
if n > 0 {
eprintln!(
"[--diagnose] [ROUTING DIVERGENT] simulated fp16 softmax picked different experts than fp32 on {n} of {rows} batch rows"
);
}
}
None => {
eprintln!("[--diagnose] routing divergence: sim fp16 vs fp32 = n/a");
}
}
let ratio_max = match ratio_max_hip {
Some(hip) if !hip.is_finite() || hip > ratio_max_sim => hip,
_ => ratio_max_sim,
};
let source = if ratio_max_hip.map_or(false, |h| h > ratio_max_sim) {
"HIP"
} else {
"sim"
};
let verdict = if !ratio_max.is_finite() {
"fp16/fp32 ratio is +inf (fp16 noise produced a non-zero grad where fp32 produced zero): noise IS a major issue"
} else if ratio_max <= 2.0 {
"fp16/fp32 ratio within 2x: noise is NOT a major issue"
} else if ratio_max <= 5.0 {
"fp16/fp32 ratio within 5x: noise is a moderate issue"
} else {
"fp16/fp32 ratio > 5x: noise IS a major issue"
};
let fmt_ratio = |r: f32| -> String {
if r.is_finite() {
format!("{:.2}", r)
} else {
"inf".to_string()
}
};
eprintln!(
"[--diagnose] verdict: {verdict} (worst-case ratio = {} from {}; \
sim={}, hip={}, routing_divergent_rows=HIP:{}/{} sim:{}/{})",
fmt_ratio(ratio_max),
source,
fmt_ratio(ratio_max_sim),
match ratio_max_hip {
Some(h) => fmt_ratio(h),
None => "n/a".to_string(),
},
hip_diverge_n,
hip_diverge.map(|(_, rows)| rows).unwrap_or(0),
sim_diverge_n,
sim_diverge.map(|(_, rows)| rows).unwrap_or(0),
);
Ok(())
}
// Unit tests for the learning-rate schedule returned by `TrainConfig::lr_at`.
//
// Every expected value is derived from the baseline `cfg()` below:
// base lr 1.0, 4 warmup steps, 20 total steps and `min_lr_ratio` 0.1,
// so the decay floor is 1.0 * 0.1 = 0.1.
#[cfg(test)]
mod lr_schedule_tests {
    use super::*;

    /// Baseline configuration shared by every test. Tests that need a
    /// different schedule shape mutate `steps` / `warmup_steps` on a copy.
    fn cfg() -> TrainConfig {
        TrainConfig {
            // Model / data selection.
            model_kind: ModelKind::MoE,
            router_kind: RouterKind::SheafPadic,
            size: MoESize::Nano,
            dataset_spec: "empty".to_string(),
            // Schedule parameters that the assertions below depend on.
            steps: 20,
            lr: 1.0,
            warmup_steps: 4,
            min_lr_ratio: 0.1,
            // Remaining optimiser settings.
            batch_size: 4,
            weight_decay: 0.0,
            grad_clip: 1.0,
            seed: 1,
            optimizer: Optimizer::Sgd,
            momentum: 0.9,
            router_lr_scale: 1.0,
            // Output locations; `dry_run` is set below, presumably so that
            // nothing is written there — verify against the trainer.
            checkpoint_dir: std::path::PathBuf::from("/tmp/lr-schedule-test"),
            metrics_path: std::path::PathBuf::from("/tmp/lr-schedule-test/m.jsonl"),
            arch_path: std::path::PathBuf::from("/tmp/lr-schedule-test/a.json"),
            // Runtime flags.
            quiet: true,
            use_hip: false,
            dry_run: true,
            diagnose: false,
        }
    }

    #[test]
    fn warmup_is_linear_from_zero() {
        let c = cfg();
        // Step 0 must be exactly zero, not merely close to it.
        assert_eq!(c.lr_at(0), 0.0);
        // Each warmup step adds lr / warmup_steps = 0.25.
        for (step, want) in [(1, 0.25), (2, 0.50), (3, 0.75), (4, 1.0)] {
            assert!((c.lr_at(step) - want).abs() < 1e-6);
        }
    }

    #[test]
    fn cosine_decay_endpoint_is_min_lr() {
        let c = cfg();
        // The final step lands on the floor lr * min_lr_ratio = 0.1.
        let lr_end = c.lr_at(20);
        assert!((lr_end - 0.1).abs() < 1e-6);
        // One step earlier the value is still inside the cosine curve,
        // consistent with 0.1 + 0.45 * (1 + cos(15π/16)) ≈ 0.1086.
        let lr_19 = c.lr_at(19);
        assert!(
            (lr_19 - 0.1087).abs() < 1e-3,
            "lr_at(19) should be ~0.1087 (cosine interior near the end), got {lr_19}"
        );
    }

    #[test]
    fn cosine_decay_midpoint_is_average_of_base_and_min() {
        let c = cfg();
        // Step 12 is halfway through the 4..=20 decay window: (1.0 + 0.1) / 2.
        let lr_mid = c.lr_at(12);
        assert!((lr_mid - 0.55).abs() < 1e-6);
    }

    #[test]
    fn warmup_zero_skips_warmup_but_still_cosines() {
        let mut c = cfg();
        c.warmup_steps = 0;
        // With no warmup the very first step already uses the full base lr.
        assert_eq!(c.lr_at(0), 1.0);
        // The next step has started decaying, consistent with
        // 0.1 + 0.45 * (1 + cos(π/20)) ≈ 0.9945.
        let lr_1 = c.lr_at(1);
        assert!(
            (lr_1 - 0.9945).abs() < 1e-3,
            "lr_at(1) with warmup=0 should be ~0.9945 (cosine interior near start), got {lr_1}"
        );
    }

    #[test]
    fn steps_at_or_below_warmup_hold_at_base_lr() {
        let mut c = cfg();
        c.steps = 4;
        c.warmup_steps = 4;
        // When the whole run is warmup there is no decay window; the ramp
        // must still climb linearly and finish exactly at the base lr.
        // NOTE(review): despite the name, no step past 4 is checked here.
        for (step, want) in [(0, 0.0), (1, 0.25), (2, 0.50), (3, 0.75), (4, 1.0)] {
            assert!((c.lr_at(step) - want).abs() < 1e-6);
        }
    }

    #[test]
    fn step_past_end_clamps_to_min_lr() {
        let c = cfg();
        // Far beyond `steps` the schedule stays pinned to the 0.1 floor.
        let lr_late = c.lr_at(100);
        assert!((lr_late - 0.1).abs() < 1e-6);
    }
}