//! rust_trainer 0.1.4
//!
//! CPU-first pure-Rust supervised trainer for Selective State Space Models with Hyperspherical Prototype Networks.
//!
//! Documentation example.
use rust_trainer::generic_trainer::{
    default_trainer_config, is_frozen_unchanged, make_batch_from_tokens, mean_layer_norm,
    GenericTrainer,
};
use rust_trainer::{ExpansionPlacement, FreezeSelection, LayerSpec};
use serde_json::json;

/// Checkpoint round-trip parity check for `GenericTrainer`.
///
/// Runs one training step on a freshly created trainer, saves it to a temporary checkpoint,
/// reloads that checkpoint into a second trainer, then runs one more step on both trainers
/// with the same batch. A JSON report is printed to stdout containing the step losses, the
/// summed absolute differences between the two trainers' embeddings and prototypes, the
/// result of `is_frozen_unchanged` for the first step, and the mean layer norms before and
/// after that step.
///
/// # Panics
///
/// Panics if the checkpoint cannot be saved or loaded, or if the report cannot be serialized.
fn main() {
    /// Removes the wrapped file when dropped, so the checkpoint is cleaned up on normal exit
    /// and also when a later step panics (as long as panics unwind).
    struct TempCheckpoint(std::path::PathBuf);

    impl Drop for TempCheckpoint {
        fn drop(&mut self) {
            // Best effort: the file may never have been created.
            let _ = std::fs::remove_file(&self.0);
        }
    }

    let spec = LayerSpec {
        d_model: 16,
        d_state: 16,
        d_conv: 4,
    };
    // Positional arguments are forwarded unchanged; see `default_trainer_config` for their
    // meaning.
    let cfg = default_trainer_config(
        64,
        spec,
        6,
        ExpansionPlacement::SpecificPositions(vec![1, 3, 4, 5]),
        FreezeSelection::FirstN(2),
        false,
        1e-4,
    );

    let mut a = GenericTrainer::new_random(cfg, 2, 2026);
    // Repeating 0..64 pattern, 1024 tokens long.
    let tokens = (0..1024).map(|v| (v % 64) as i64).collect::<Vec<_>>();

    // First step: record layer norms on both sides to check the frozen layers.
    let before = a.layer_l2_norms();
    let (ids1, tgt1) = make_batch_from_tokens(&tokens, 0, 2, 8);
    let s1 = a.train_step(&ids1, &tgt1);
    let mid = a.layer_l2_norms();
    // NOTE(review): 1e-8 is presumably a tolerance on the norm difference -- confirm.
    let frozen_ok = is_frozen_unchanged(&before, &mid, &a.frozen_layer_indices, 1e-8);

    // Include the process id so concurrent runs do not overwrite each other's checkpoint.
    // Declared before `b` so that `b` is dropped before the file is removed.
    let ckpt = TempCheckpoint(std::env::temp_dir().join(format!(
        "trainer_parity_roundtrip_{}.bincode",
        std::process::id()
    )));
    a.save_checkpoint(&ckpt.0)
        .expect("failed to save checkpoint to the temp directory");
    let mut b = GenericTrainer::load_checkpoint(&ckpt.0)
        .expect("failed to load the checkpoint that was just saved");

    // Second step: feed the identical batch to the original and the reloaded trainer.
    let (ids2, tgt2) = make_batch_from_tokens(&tokens, 16, 2, 8);
    let sa = a.train_step(&ids2, &tgt2);
    let sb = b.train_step(&ids2, &tgt2);

    let emb_delta = (&a.params.embedding - &b.params.embedding)
        .mapv(f32::abs)
        .sum();
    let proto_delta = (&a.prototypes - &b.prototypes).mapv(f32::abs).sum();
    let out = json!({
        "first_step_loss": s1.loss,
        "resume_step_loss_a": sa.loss,
        "resume_step_loss_b": sb.loss,
        "resume_loss_abs_diff": (sa.loss - sb.loss).abs(),
        "embedding_abs_diff_after_resume_step": emb_delta,
        "prototypes_abs_diff_after_resume_step": proto_delta,
        "frozen_layers_unchanged": frozen_ok,
        "mean_layer_norm_before": mean_layer_norm(&before),
        "mean_layer_norm_after_one_step": mean_layer_norm(&mid),
    });
    println!(
        "{}",
        serde_json::to_string_pretty(&out).expect("failed to serialize the JSON report")
    );
    // `ckpt` is dropped here (after `b`), which removes the checkpoint file.
}