tokitai-operator 0.1.0

Verified DL kernel compiler: formally-checked GEMM, p-adic, sheaf, contract-carrying ops. Paper-artifact grade.
Documentation
//! End-to-end training runner for the 0.7B MoE model (gated on `rocm-hip`).
//!
//! Wires `synth_data` -> `dataset_bridge` -> `moe_model` ->
//! `training::step` -> `checkpoint` -> `model_arch` and runs the
//! full training loop. The default `cargo run --bin
//! train_quality_moe --features rocm-hip` produces
//! `arch.json` and `checkpoint.tkp1` at the end of the run.
//!
#![cfg(feature = "rocm-hip")]

//! End-to-end training runner for the 0.7B MoE quality-decision
//! model. See `docs/MOE_TRAINING_PLAN.md` for the kernel-level
//! walkthrough; this binary is the Phase 2.11 orchestration entry
//! point that wires `synth_data` / `dataset_bridge` / `moe_model` /
//! `checkpoint` / `model_arch` / `metrics` into a single CLI.

use std::env;
use std::fs;
use std::path::PathBuf;
use std::process::ExitCode;

use tokitai_operator::moe_model::MoESize;
use tokitai_operator::training_runner::{
    Optimizer, TrainConfig, default_steps, run_diagnose, run_training,
};

/// Binary entry point: parse the CLI, then run either the one-shot diagnose
/// pass or the full training loop.
///
/// Exit status:
/// * `2` - the command line could not be parsed (usage is printed to stderr);
/// * `FAILURE` - diagnose/training returned an error, or training returned
///   `Ok` but the post-run sanity checks below rejected the result;
/// * `SUCCESS` - otherwise.
fn main() -> ExitCode {
    let cli_args: Vec<String> = env::args().skip(1).collect();
    let config = match parse_args(&cli_args) {
        Ok(config) => config,
        Err(e) => {
            eprintln!("train_quality_moe: {e}");
            eprintln!();
            print_usage();
            return ExitCode::from(2);
        }
    };

    // Diagnose mode never enters the training loop.
    if config.diagnose {
        if let Err(e) = run_diagnose(&config) {
            eprintln!("train_quality_moe: diagnose failed: {e}");
            return ExitCode::FAILURE;
        }
        return ExitCode::SUCCESS;
    }

    let summary = match run_training(&config) {
        Ok(summary) => summary,
        Err(e) => {
            eprintln!("train_quality_moe: training failed: {e}");
            return ExitCode::FAILURE;
        }
    };

    // The one-line summary is printed before validation so it is available
    // even when one of the checks below fails the run.
    println!(
        "SUMMARY: model_size={} total_params={} final_loss={:.6} steps_run={} time_elapsed_sec={:.3} throughput_steps_per_sec={:.2} last_router_entropy={:.4} checkpoint={} arch={} fingerprint={}",
        summary.model_size,
        summary.total_params,
        summary.final_loss,
        summary.steps_run,
        summary.time_elapsed_sec,
        summary.throughput_steps_per_sec,
        summary.last_router_entropy_mean,
        summary.checkpoint_path.display(),
        summary.arch_path.display(),
        summary.arch_fingerprint,
    );

    // Post-run sanity checks. An `Ok` summary is not trusted on its own: a
    // NaN/inf loss, a step count that differs from the request, or a
    // missing/empty metrics file are all reported as a non-zero exit.
    if !summary.final_loss.is_finite() {
        eprintln!(
            "train_quality_moe: training failed: final_loss is not finite ({}); \
             check for NaN/inf in forward/backward or HIP kernel faults",
            summary.final_loss
        );
        return ExitCode::FAILURE;
    }

    if summary.steps_run != config.steps {
        eprintln!(
            "train_quality_moe: training failed: steps_run={} but expected {}; \
             training terminated early",
            summary.steps_run, config.steps
        );
        return ExitCode::FAILURE;
    }

    match fs::metadata(&config.metrics_path) {
        Err(e) => {
            eprintln!(
                "train_quality_moe: training failed: cannot stat metrics file {}: {e}",
                config.metrics_path.display()
            );
            ExitCode::FAILURE
        }
        Ok(md) if md.len() == 0 => {
            eprintln!(
                "train_quality_moe: training failed: metrics file {} is empty (0 bytes)",
                config.metrics_path.display()
            );
            ExitCode::FAILURE
        }
        Ok(_) => ExitCode::SUCCESS,
    }
}

/// Prints the command-line help text to stderr.
///
/// The three `{}` placeholders on the `--steps` line are filled from
/// `default_steps` for the Nano, Tiny and Medium sizes. The remaining
/// defaults are written out literally and must be kept in sync with the
/// initial values in `parse_args`.
///
/// Every output line is its own literal inside `concat!`. A `\` line
/// continuation in a Rust string strips the leading whitespace of the next
/// source line, which would silently drop the indentation of wrapped
/// descriptions; spelling the indentation inside each literal avoids that.
fn print_usage() {
    eprintln!(
        concat!(
            "Usage: train_quality_moe [options]\n",
            "\n",
            "Options:\n",
            "  --dataset <spec>          synth:quality:N, synth:regression:N, or a SQLite-ledger dir\n",
            "  --size nano|tiny|medium|full   model size (default: tiny). Nano is a ~195K-param\n",
            "                            hyperparameter-iteration size (~1000x faster than\n",
            "                            Tiny); the topology and I/O schema are identical.\n",
            "  --steps <N>               total optimizer steps (default: {} for nano, {} for tiny, {} for medium/full)\n",
            "  --batch <N>               batch size, a multiple of 16 (default: 128)\n",
            "  --lr <F>                  learning rate (default: 0.001 for AdamW, e.g. 0.05 for SGD)\n",
            "  --optimizer adamw|sgd     optimizer (default: adamw). SGD uses --momentum and reuses\n",
            "                            the AdamW `m` buffer as the velocity buffer; choose one\n",
            "                            optimizer per run.\n",
            "  --momentum <F>            momentum for --optimizer sgd (default: 0.9, ignored for AdamW)\n",
            "  --router-lr-scale <F>     per-module LR scale for the router parameters\n",
            "                            (default: 0.1). AdamW normalises per-parameter, so the\n",
            "                            router's per-element gradient (which is ~7.7x larger than\n",
            "                            each expert's) drives a 7.7x-too-large update. Scaling the\n",
            "                            router's effective LR compensates and prevents softmax\n",
            "                            saturation in the first 1-2 steps.\n",
            "  --weight-decay <F>        AdamW weight decay (default: 0.01, ignored for SGD)\n",
            "  --grad-clip <F>           global L2 grad clip (default: 1.0)\n",
            "  --warmup-steps <N>        number of warmup steps (default: 5)\n",
            "  --min-lr-ratio <F>        minimum LR ratio, in [0, 1] (default: 0.1)\n",
            "  --seed <u32>              PRNG seed (default: 42)\n",
            "  --checkpoint-dir <path>   output dir (default: ./var/training/<ts>)\n",
            "  --metrics-path <path>     metrics JSONL (default: <checkpoint-dir>/metrics.jsonl)\n",
            "  --arch-path <path>        arch JSON (default: <checkpoint-dir>/arch.json)\n",
            "  --quiet                   suppress per-step stderr logs\n",
            "  --dry-run                 skip the HIP forward/backward path; write synthetic metrics\n",
            "                            and a stub TKP1 checkpoint. Useful for CI / smoke tests where\n",
            "                            hipcc is unavailable or unreliable, and safe for the Full size\n",
            "                            (no ~10GB AdamW allocation)\n",
            "  --diagnose                run ONE forward+backward step with both fp16 (HIP path) and\n",
            "                            fp32 (CPU reference) and print a per-parameter grad_norm\n",
            "                            comparison. Exits before the training loop. Used to test the\n",
            "                            hypothesis that residual grad_norm growth in the 0.7B MoE\n",
            "                            training is fp16 accumulation noise.\n",
            "  -h, --help                print this message and exit",
        ),
        default_steps(MoESize::Nano),
        default_steps(MoESize::Tiny),
        default_steps(MoESize::Medium),
    );
}

fn parse_args(args: &[String]) -> Result<TrainConfig, String> {
    let mut size = MoESize::Tiny;
    let mut steps: Option<u32> = None;
    let mut batch_size: usize = 128;
    let mut lr: f32 = 0.001;
    let mut weight_decay: f32 = 0.01;
    let mut grad_clip: f32 = 1.0;
    let mut seed: u32 = 42;
    let mut dataset_spec: Option<String> = None;
    let mut checkpoint_dir: Option<PathBuf> = None;
    let mut metrics_path: Option<PathBuf> = None;
    let mut arch_path: Option<PathBuf> = None;
    let mut quiet = false;
    let mut dry_run = false;
    let mut diagnose = false;
    let mut optimizer: Optimizer = Optimizer::Adamw;
    let mut momentum: f32 = 0.9;
    let mut router_lr_scale: f32 = 0.1;
    let mut warmup_steps: u32 = 5;
    let mut min_lr_ratio: f32 = 0.1;

    let mut i = 0;
    while i < args.len() {
        let a = &args[i];
        match a.as_str() {
            "--dataset" => {
                dataset_spec = Some(next_value(args, i, "--dataset")?);
                i += 2;
            }
            "--size" => {
                let v = next_value(args, i, "--size")?;
                size = match v.as_str() {
                    "nano" | "Nano" => MoESize::Nano,
                    "tiny" | "Tiny" => MoESize::Tiny,
                    "medium" | "Medium" => MoESize::Medium,
                    "full" | "Full" => MoESize::Full,
                    other => return Err(format!("unknown --size value: {other}")),
                };
                i += 2;
            }
            "--steps" => {
                let v = next_value(args, i, "--steps")?;
                steps = Some(
                    v.parse::<u32>()
                        .map_err(|e| format!("parse --steps: {e}"))?,
                );
                i += 2;
            }
            "--batch" => {
                let v = next_value(args, i, "--batch")?;
                batch_size = v
                    .parse::<usize>()
                    .map_err(|e| format!("parse --batch: {e}"))?;
                if batch_size == 0 {
                    return Err("--batch 0 is invalid (must be > 0)".to_string());
                }
                if batch_size % 16 != 0 {
                    return Err(format!(
                        "--batch {batch_size} is not a multiple of 16; \
                         Linear fp16 GEMM requires the batch dim to be a multiple of 16. \
                         Use 16, 32, 64, 128, 256, ..."
                    ));
                }
                i += 2;
            }
            "--lr" => {
                let v = next_value(args, i, "--lr")?;
                lr = v.parse::<f32>().map_err(|e| format!("parse --lr: {e}"))?;
                if !lr.is_finite() || lr <= 0.0 {
                    return Err(format!("--lr {lr} is not a positive finite number"));
                }
                i += 2;
            }
            "--weight-decay" => {
                let v = next_value(args, i, "--weight-decay")?;
                weight_decay = v
                    .parse::<f32>()
                    .map_err(|e| format!("parse --weight-decay: {e}"))?;
                i += 2;
            }
            "--grad-clip" => {
                let v = next_value(args, i, "--grad-clip")?;
                grad_clip = v
                    .parse::<f32>()
                    .map_err(|e| format!("parse --grad-clip: {e}"))?;
                i += 2;
            }
            "--seed" => {
                let v = next_value(args, i, "--seed")?;
                seed = v.parse::<u32>().map_err(|e| format!("parse --seed: {e}"))?;
                i += 2;
            }
            "--checkpoint-dir" => {
                checkpoint_dir = Some(PathBuf::from(next_value(args, i, "--checkpoint-dir")?));
                i += 2;
            }
            "--metrics-path" => {
                metrics_path = Some(PathBuf::from(next_value(args, i, "--metrics-path")?));
                i += 2;
            }
            "--arch-path" => {
                arch_path = Some(PathBuf::from(next_value(args, i, "--arch-path")?));
                i += 2;
            }
            "--quiet" => {
                quiet = true;
                i += 1;
            }
            "--dry-run" => {
                dry_run = true;
                i += 1;
            }
            "--diagnose" => {
                diagnose = true;
                i += 1;
            }
            "--optimizer" => {
                let v = next_value(args, i, "--optimizer")?;
                optimizer = match v.as_str() {
                    "adamw" | "Adamw" | "ADAMW" => Optimizer::Adamw,
                    "sgd" | "Sgd" | "SGD" => Optimizer::Sgd,
                    other => {
                        return Err(format!(
                            "unknown --optimizer value: {other} (expected adamw or sgd)"
                        ));
                    }
                };
                i += 2;
            }
            "--momentum" => {
                let v = next_value(args, i, "--momentum")?;
                momentum = v
                    .parse::<f32>()
                    .map_err(|e| format!("parse --momentum: {e}"))?;
                if !momentum.is_finite() || !(0.0..=1.0).contains(&momentum) {
                    return Err(format!(
                        "--momentum {momentum} is out of [0, 1]; \
                         values >1.0 silently diverge the SGD velocity"
                    ));
                }
                i += 2;
            }
            "--router-lr-scale" => {
                let v = next_value(args, i, "--router-lr-scale")?;
                router_lr_scale = v
                    .parse::<f32>()
                    .map_err(|e| format!("parse --router-lr-scale: {e}"))?;
                if !router_lr_scale.is_finite() || !(0.0..=10.0).contains(&router_lr_scale) {
                    return Err(format!(
                        "--router-lr-scale {router_lr_scale} is out of [0, 10] or non-finite; \
                         expected a small positive scalar (0.01-1.0 typical, 0.0 freezes the router)"
                    ));
                }
                i += 2;
            }
            "--warmup-steps" => {
                let v = next_value(args, i, "--warmup-steps")?;
                warmup_steps = v
                    .parse::<u32>()
                    .map_err(|e| format!("parse --warmup-steps: {e}"))?;
                i += 2;
            }
            "--min-lr-ratio" => {
                let v = next_value(args, i, "--min-lr-ratio")?;
                min_lr_ratio = v
                    .parse::<f32>()
                    .map_err(|e| format!("parse --min-lr-ratio: {e}"))?;
                if !min_lr_ratio.is_finite() || !(0.0..=1.0).contains(&min_lr_ratio) {
                    return Err(format!(
                        "--min-lr-ratio {min_lr_ratio} is out of [0, 1] or non-finite"
                    ));
                }
                i += 2;
            }
            "-h" | "--help" => {
                print_usage();
                std::process::exit(0);
            }
            other => return Err(format!("unknown argument: {other}")),
        }
    }

    let checkpoint_dir = checkpoint_dir.unwrap_or_else(default_checkpoint_dir);
    let metrics_path = metrics_path.unwrap_or_else(|| checkpoint_dir.join("metrics.jsonl"));
    let arch_path = arch_path.unwrap_or_else(|| checkpoint_dir.join("arch.json"));
    let dataset_spec = dataset_spec.unwrap_or_else(|| "synth:quality:512".to_string());
    let steps = steps.unwrap_or_else(|| default_steps(size));

    // The default LR (0.001) is tuned for AdamW. SGD with momentum
    // typically needs 0.01-0.1; 0.001 will simply not move the model.
    // Warn (don't error) so the user can opt into a sub-optimal run
    // for an ablation study, but make the misconfig obvious.
    if optimizer == Optimizer::Sgd && lr < 0.01 {
        eprintln!(
            "train_quality_moe: WARNING --optimizer sgd with --lr {} is likely sub-optimal; \
             SGD typically needs --lr 0.05 (the 10K MLP gate used 0.05 with momentum 0.9). \
             Override --lr explicitly to silence this warning.",
            lr
        );
    }

    Ok(TrainConfig {
        // CLI defaults to the historical MoE + sheaf+padic
        // family. The 3-way ablation arms
        // (ModelKind::MoE+RouterKind::SoftmaxOnly, ModelKind::Dense)
        // are reachable through the new factory methods on
        // TrainConfig (e.g. TrainConfig::tiny_softmax_moe /
        // tiny_dense) — a future CLI flag will let the user
        // select between them, but for now the binary keeps
        // the historical default.
        model_kind: tokitai_operator::model_arch::ModelKind::MoE,
        router_kind: tokitai_operator::model_arch::RouterKind::SheafPadic,
        size,
        steps,
        batch_size,
        lr,
        weight_decay,
        grad_clip,
        seed,
        checkpoint_dir,
        metrics_path,
        arch_path,
        quiet,
        dataset_spec,
        use_hip: std::env::var("TOKITAI_TRAIN_HIP")
            .ok()
            .map(|v| v == "1" || v.eq_ignore_ascii_case("true"))
            .unwrap_or(false),
        dry_run,
        optimizer,
        momentum,
        router_lr_scale,
        diagnose,
        warmup_steps,
        min_lr_ratio,
    })
}

/// Returns an owned copy of the argument that immediately follows position
/// `i`, i.e. the value belonging to the flag found at `args[i]`.
///
/// # Errors
/// Returns `"<flag> requires a value"` when `args` has no element at `i + 1`.
fn next_value(args: &[String], i: usize, flag: &str) -> Result<String, String> {
    match args.get(i + 1) {
        Some(value) => Ok(value.clone()),
        None => Err(format!("{flag} requires a value")),
    }
}

/// Builds the default output directory `./var/training/<ts>`, where `<ts>` is
/// the current wall-clock time in whole seconds since the Unix epoch.
///
/// If the system clock reads earlier than the epoch, `<ts>` is `0`. The
/// directory is only named here, not created.
fn default_checkpoint_dir() -> PathBuf {
    let now = std::time::SystemTime::now();
    let ts = match now.duration_since(std::time::UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    };
    let rendered = format!("./var/training/{ts}");
    PathBuf::from(rendered)
}