#![cfg(feature = "rocm-hip")]
use std::env;
use std::fs;
use std::path::PathBuf;
use std::process::ExitCode;
use tokitai_operator::moe_model::MoESize;
use tokitai_operator::training_runner::{
Optimizer, TrainConfig, default_steps, run_diagnose, run_training,
};
/// Program entry point.
///
/// Parses the command line, then either runs the one-step fp16/fp32 gradient
/// diagnosis (`--diagnose`) or the full training loop.
///
/// Exit status:
/// * `2` — command-line error (the message and usage text go to stderr);
/// * `ExitCode::FAILURE` — diagnosis or training returned an error, or one of
///   the post-run sanity checks below rejected an otherwise "successful" run;
/// * `ExitCode::SUCCESS` — otherwise.
fn main() -> ExitCode {
    let argv: Vec<String> = env::args().skip(1).collect();
    let config = match parse_args(&argv) {
        Ok(config) => config,
        Err(msg) => {
            eprintln!("train_quality_moe: {msg}");
            eprintln!();
            print_usage();
            return ExitCode::from(2);
        }
    };

    // Diagnosis mode never enters the training loop.
    if config.diagnose {
        if let Err(err) = run_diagnose(&config) {
            eprintln!("train_quality_moe: diagnose failed: {err}");
            return ExitCode::FAILURE;
        }
        return ExitCode::SUCCESS;
    }

    let summary = match run_training(&config) {
        Ok(summary) => summary,
        Err(err) => {
            eprintln!("train_quality_moe: training failed: {err}");
            return ExitCode::FAILURE;
        }
    };

    // Machine-readable one-line summary on stdout; printed before the sanity
    // checks so it is available even for runs that are then rejected.
    println!(
        "SUMMARY: model_size={} total_params={} final_loss={:.6} steps_run={} time_elapsed_sec={:.3} throughput_steps_per_sec={:.2} last_router_entropy={:.4} checkpoint={} arch={} fingerprint={}",
        summary.model_size,
        summary.total_params,
        summary.final_loss,
        summary.steps_run,
        summary.time_elapsed_sec,
        summary.throughput_steps_per_sec,
        summary.last_router_entropy_mean,
        summary.checkpoint_path.display(),
        summary.arch_path.display(),
        summary.arch_fingerprint,
    );

    // run_training returning Ok is not enough: reject non-finite loss,
    // early termination, and a missing/empty metrics file.
    if !summary.final_loss.is_finite() {
        eprintln!(
            "train_quality_moe: training failed: final_loss is not finite ({}); \
             check for NaN/inf in forward/backward or HIP kernel faults",
            summary.final_loss
        );
        return ExitCode::FAILURE;
    }

    if summary.steps_run != config.steps {
        eprintln!(
            "train_quality_moe: training failed: steps_run={} but expected {}; \
             training terminated early",
            summary.steps_run, config.steps
        );
        return ExitCode::FAILURE;
    }

    let metrics_len = match fs::metadata(&config.metrics_path) {
        Ok(meta) => meta.len(),
        Err(err) => {
            eprintln!(
                "train_quality_moe: training failed: cannot stat metrics file {}: {err}",
                config.metrics_path.display()
            );
            return ExitCode::FAILURE;
        }
    };
    if metrics_len == 0 {
        eprintln!(
            "train_quality_moe: training failed: metrics file {} is empty (0 bytes)",
            config.metrics_path.display()
        );
        return ExitCode::FAILURE;
    }

    ExitCode::SUCCESS
}
/// Print the command-line help text to stderr.
///
/// Every default quoted here must match the initial value of the matching
/// variable in `parse_args` (e.g. `--batch` starts at 128 there). The
/// `--steps` defaults are taken directly from `default_steps`, so they cannot
/// drift.
fn print_usage() {
    // Column (after the 2-space indent) at which descriptions start; wide
    // enough for the longest flag, `--size nano|tiny|medium|full` (28 chars).
    const FLAG_WIDTH: usize = 30;

    let steps_help = format!(
        "total optimizer steps (default: {} for nano, {} for tiny, {} for medium/full)",
        default_steps(MoESize::Nano),
        default_steps(MoESize::Tiny),
        default_steps(MoESize::Medium),
    );
    let steps_lines = [steps_help.as_str()];

    // (flag, description lines). Lines after the first are printed under the
    // description column with an empty flag cell.
    let options: &[(&str, &[&str])] = &[
        (
            "--dataset <spec>",
            &[
                "synth:quality:N, synth:regression:N, or a SQLite-ledger dir",
                "(default: synth:quality:512)",
            ],
        ),
        (
            "--size nano|tiny|medium|full",
            &[
                "model size (default: tiny). Nano is a ~195K-param",
                "hyperparameter-iteration size (~1000x faster than",
                "Tiny); the topology and I/O schema are identical.",
            ],
        ),
        ("--steps <N>", &steps_lines),
        (
            "--batch <N>",
            &["batch size, a positive multiple of 16 (default: 128)"],
        ),
        (
            "--lr <F>",
            &["learning rate (default: 0.001 for AdamW, e.g. 0.05 for SGD)"],
        ),
        (
            "--optimizer adamw|sgd",
            &[
                "optimizer (default: adamw). SGD uses --momentum and reuses",
                "the AdamW `m` buffer as the velocity buffer; choose one",
                "optimizer per run.",
            ],
        ),
        (
            "--momentum <F>",
            &[
                "momentum for --optimizer sgd, in [0, 1]",
                "(default: 0.9, ignored for AdamW)",
            ],
        ),
        (
            "--router-lr-scale <F>",
            &[
                "per-module LR scale for the router parameters, in [0, 10]",
                "(default: 0.1). AdamW normalises per-parameter, so the",
                "router's per-element gradient (which is ~7.7x larger than",
                "each expert's) drives a 7.7x-too-large update. Scaling the",
                "router's effective LR compensates and prevents softmax",
                "saturation in the first 1-2 steps.",
            ],
        ),
        (
            "--weight-decay <F>",
            &["AdamW weight decay, >= 0 (default: 0.01, ignored for SGD)"],
        ),
        (
            "--grad-clip <F>",
            &["global L2 grad clip, >= 0 (default: 1.0)"],
        ),
        ("--warmup-steps <N>", &["warmup step count (default: 5)"]),
        (
            "--min-lr-ratio <F>",
            &["minimum learning-rate ratio, in [0, 1] (default: 0.1)"],
        ),
        ("--seed <u32>", &["PRNG seed (default: 42)"]),
        (
            "--checkpoint-dir <path>",
            &["output dir (default: ./var/training/<ts>)"],
        ),
        (
            "--metrics-path <path>",
            &["metrics JSONL (default: <checkpoint-dir>/metrics.jsonl)"],
        ),
        (
            "--arch-path <path>",
            &["arch JSON (default: <checkpoint-dir>/arch.json)"],
        ),
        ("--quiet", &["suppress per-step stderr logs"]),
        (
            "--dry-run",
            &[
                "skip the HIP forward/backward path; write synthetic metrics",
                "and a stub TKP1 checkpoint. Useful for CI / smoke tests where",
                "hipcc is unavailable or unreliable, and safe for the Full size",
                "(no ~10GB AdamW allocation)",
            ],
        ),
        (
            "--diagnose",
            &[
                "run ONE forward+backward step with both fp16 (HIP path) and",
                "fp32 (CPU reference) and print a per-parameter grad_norm",
                "comparison. Exits before the training loop. Used to test the",
                "hypothesis that residual grad_norm growth in the 0.7B MoE",
                "training is fp16 accumulation noise.",
            ],
        ),
        ("-h, --help", &["print this help and exit"]),
    ];

    let mut text = String::from("Usage: train_quality_moe [options]\n\nOptions:\n");
    for (flag, lines) in options {
        for (n, line) in lines.iter().enumerate() {
            // Only the first description line carries the flag itself.
            let label = if n == 0 { *flag } else { "" };
            text.push_str(&format!("  {label:<width$}{line}\n", width = FLAG_WIDTH));
        }
    }

    // Mirrors the `use_hip` check in `parse_args`: "1" or "true" (any case).
    text.push_str("\nEnvironment:\n");
    text.push_str(&format!(
        "  {:<width$}{}\n",
        "TOKITAI_TRAIN_HIP=1|true",
        "enable the HIP path (case-insensitive; default: off)",
        width = FLAG_WIDTH,
    ));

    eprint!("{text}");
}
fn parse_args(args: &[String]) -> Result<TrainConfig, String> {
let mut size = MoESize::Tiny;
let mut steps: Option<u32> = None;
let mut batch_size: usize = 128;
let mut lr: f32 = 0.001;
let mut weight_decay: f32 = 0.01;
let mut grad_clip: f32 = 1.0;
let mut seed: u32 = 42;
let mut dataset_spec: Option<String> = None;
let mut checkpoint_dir: Option<PathBuf> = None;
let mut metrics_path: Option<PathBuf> = None;
let mut arch_path: Option<PathBuf> = None;
let mut quiet = false;
let mut dry_run = false;
let mut diagnose = false;
let mut optimizer: Optimizer = Optimizer::Adamw;
let mut momentum: f32 = 0.9;
let mut router_lr_scale: f32 = 0.1;
let mut warmup_steps: u32 = 5;
let mut min_lr_ratio: f32 = 0.1;
let mut i = 0;
while i < args.len() {
let a = &args[i];
match a.as_str() {
"--dataset" => {
dataset_spec = Some(next_value(args, i, "--dataset")?);
i += 2;
}
"--size" => {
let v = next_value(args, i, "--size")?;
size = match v.as_str() {
"nano" | "Nano" => MoESize::Nano,
"tiny" | "Tiny" => MoESize::Tiny,
"medium" | "Medium" => MoESize::Medium,
"full" | "Full" => MoESize::Full,
other => return Err(format!("unknown --size value: {other}")),
};
i += 2;
}
"--steps" => {
let v = next_value(args, i, "--steps")?;
steps = Some(
v.parse::<u32>()
.map_err(|e| format!("parse --steps: {e}"))?,
);
i += 2;
}
"--batch" => {
let v = next_value(args, i, "--batch")?;
batch_size = v
.parse::<usize>()
.map_err(|e| format!("parse --batch: {e}"))?;
if batch_size == 0 {
return Err("--batch 0 is invalid (must be > 0)".to_string());
}
if batch_size % 16 != 0 {
return Err(format!(
"--batch {batch_size} is not a multiple of 16; \
Linear fp16 GEMM requires the batch dim to be a multiple of 16. \
Use 16, 32, 64, 128, 256, ..."
));
}
i += 2;
}
"--lr" => {
let v = next_value(args, i, "--lr")?;
lr = v.parse::<f32>().map_err(|e| format!("parse --lr: {e}"))?;
if !lr.is_finite() || lr <= 0.0 {
return Err(format!("--lr {lr} is not a positive finite number"));
}
i += 2;
}
"--weight-decay" => {
let v = next_value(args, i, "--weight-decay")?;
weight_decay = v
.parse::<f32>()
.map_err(|e| format!("parse --weight-decay: {e}"))?;
i += 2;
}
"--grad-clip" => {
let v = next_value(args, i, "--grad-clip")?;
grad_clip = v
.parse::<f32>()
.map_err(|e| format!("parse --grad-clip: {e}"))?;
i += 2;
}
"--seed" => {
let v = next_value(args, i, "--seed")?;
seed = v.parse::<u32>().map_err(|e| format!("parse --seed: {e}"))?;
i += 2;
}
"--checkpoint-dir" => {
checkpoint_dir = Some(PathBuf::from(next_value(args, i, "--checkpoint-dir")?));
i += 2;
}
"--metrics-path" => {
metrics_path = Some(PathBuf::from(next_value(args, i, "--metrics-path")?));
i += 2;
}
"--arch-path" => {
arch_path = Some(PathBuf::from(next_value(args, i, "--arch-path")?));
i += 2;
}
"--quiet" => {
quiet = true;
i += 1;
}
"--dry-run" => {
dry_run = true;
i += 1;
}
"--diagnose" => {
diagnose = true;
i += 1;
}
"--optimizer" => {
let v = next_value(args, i, "--optimizer")?;
optimizer = match v.as_str() {
"adamw" | "Adamw" | "ADAMW" => Optimizer::Adamw,
"sgd" | "Sgd" | "SGD" => Optimizer::Sgd,
other => {
return Err(format!(
"unknown --optimizer value: {other} (expected adamw or sgd)"
));
}
};
i += 2;
}
"--momentum" => {
let v = next_value(args, i, "--momentum")?;
momentum = v
.parse::<f32>()
.map_err(|e| format!("parse --momentum: {e}"))?;
if !momentum.is_finite() || !(0.0..=1.0).contains(&momentum) {
return Err(format!(
"--momentum {momentum} is out of [0, 1]; \
values >1.0 silently diverge the SGD velocity"
));
}
i += 2;
}
"--router-lr-scale" => {
let v = next_value(args, i, "--router-lr-scale")?;
router_lr_scale = v
.parse::<f32>()
.map_err(|e| format!("parse --router-lr-scale: {e}"))?;
if !router_lr_scale.is_finite() || !(0.0..=10.0).contains(&router_lr_scale) {
return Err(format!(
"--router-lr-scale {router_lr_scale} is out of [0, 10] or non-finite; \
expected a small positive scalar (0.01-1.0 typical, 0.0 freezes the router)"
));
}
i += 2;
}
"--warmup-steps" => {
let v = next_value(args, i, "--warmup-steps")?;
warmup_steps = v
.parse::<u32>()
.map_err(|e| format!("parse --warmup-steps: {e}"))?;
i += 2;
}
"--min-lr-ratio" => {
let v = next_value(args, i, "--min-lr-ratio")?;
min_lr_ratio = v
.parse::<f32>()
.map_err(|e| format!("parse --min-lr-ratio: {e}"))?;
if !min_lr_ratio.is_finite() || !(0.0..=1.0).contains(&min_lr_ratio) {
return Err(format!(
"--min-lr-ratio {min_lr_ratio} is out of [0, 1] or non-finite"
));
}
i += 2;
}
"-h" | "--help" => {
print_usage();
std::process::exit(0);
}
other => return Err(format!("unknown argument: {other}")),
}
}
let checkpoint_dir = checkpoint_dir.unwrap_or_else(default_checkpoint_dir);
let metrics_path = metrics_path.unwrap_or_else(|| checkpoint_dir.join("metrics.jsonl"));
let arch_path = arch_path.unwrap_or_else(|| checkpoint_dir.join("arch.json"));
let dataset_spec = dataset_spec.unwrap_or_else(|| "synth:quality:512".to_string());
let steps = steps.unwrap_or_else(|| default_steps(size));
if optimizer == Optimizer::Sgd && lr < 0.01 {
eprintln!(
"train_quality_moe: WARNING --optimizer sgd with --lr {} is likely sub-optimal; \
SGD typically needs --lr 0.05 (the 10K MLP gate used 0.05 with momentum 0.9). \
Override --lr explicitly to silence this warning.",
lr
);
}
Ok(TrainConfig {
model_kind: tokitai_operator::model_arch::ModelKind::MoE,
router_kind: tokitai_operator::model_arch::RouterKind::SheafPadic,
size,
steps,
batch_size,
lr,
weight_decay,
grad_clip,
seed,
checkpoint_dir,
metrics_path,
arch_path,
quiet,
dataset_spec,
use_hip: std::env::var("TOKITAI_TRAIN_HIP")
.ok()
.map(|v| v == "1" || v.eq_ignore_ascii_case("true"))
.unwrap_or(false),
dry_run,
optimizer,
momentum,
router_lr_scale,
diagnose,
warmup_steps,
min_lr_ratio,
})
}
/// Return an owned copy of the argument immediately after position `i`,
/// i.e. the value belonging to the flag at `args[i]`.
///
/// # Errors
///
/// Returns "`<flag>` requires a value" when `args` ends at position `i`.
fn next_value(args: &[String], i: usize, flag: &str) -> Result<String, String> {
    match args.get(i + 1) {
        Some(value) => Ok(value.clone()),
        None => Err(format!("{flag} requires a value")),
    }
}
/// Default output directory: `./var/training/<unix-seconds>`.
///
/// If the system clock reports a time before the Unix epoch, the timestamp
/// component falls back to `0` instead of failing.
fn default_checkpoint_dir() -> PathBuf {
    let now = std::time::SystemTime::now();
    let secs = now
        .duration_since(std::time::UNIX_EPOCH)
        .map_or(0, |elapsed| elapsed.as_secs());
    PathBuf::from(format!("./var/training/{secs}"))
}