//! rust_trainer 0.1.4
//!
//! CPU-first, pure-Rust supervised trainer for Selective State Space Models with
//! Hyperspherical Prototype Networks.
//!
//! Documentation
use rust_trainer::{
    AdamWConfig, ExpansionConfig, ExpansionPlacement, ExperimentalTrainer,
    ExperimentalTrainerConfig, FreezeSelection, LayerSpec, TrainerParams,
};
use serde_json::json;
use std::env;

/// Parses this example's command-line flags.
///
/// Every flag takes exactly one value (the next argument):
/// - `--placement <spec>` — see [`parse_placement`]; default `append`
/// - `--freeze <spec>` — see [`parse_freeze`]; default `first:2`
/// - `--target <n>` — target layer count for expansion; default `6`
/// - `--base-layers <n>` — layer count passed to `TrainerParams::random`; default `2`
/// - `--seed <n>` — seed passed to `TrainerParams::random`; default `42`
/// - `--cycles <n>` — number of training cycles to run; default `1`
///
/// Returns `(placement, freeze, target, base_layers, seed, cycles)`.
///
/// Parsing is best-effort and never aborts:
/// - An unparseable numeric value keeps the flag's current value. That is its
///   default, unless the flag already appeared earlier.
/// - Unknown arguments, and flags given without a value, are skipped.
///
/// Every such case is reported on stderr, so stdout stays pure JSON.
fn parse_args() -> (
    ExpansionPlacement,
    FreezeSelection,
    usize,
    usize,
    u64,
    usize,
) {
    /// Parses `raw` as the value of `flag`. If `raw` is not a valid number,
    /// prints a warning and returns `current` unchanged.
    fn number_or<T: std::str::FromStr>(flag: &str, raw: &str, current: T) -> T {
        raw.parse().unwrap_or_else(|_| {
            eprintln!("warning: ignoring invalid value {:?} for {}", raw, flag);
            current
        })
    }

    // Defaults live only here; invalid input falls back to these via `number_or`.
    let mut placement = ExpansionPlacement::Append;
    let mut freeze = FreezeSelection::FirstN(2);
    let mut target = 6usize;
    let mut base_layers = 2usize;
    let mut seed = 42u64;
    let mut cycles = 1usize;

    let args = env::args().skip(1).collect::<Vec<_>>();
    let mut idx = 0usize;
    while idx < args.len() {
        let flag = args[idx].as_str();
        let value = args.get(idx + 1).map(String::as_str);
        // True when the argument after `flag` was consumed as its value.
        let took_value = match (flag, value) {
            ("--placement", Some(raw)) => {
                placement = parse_placement(raw);
                true
            }
            ("--freeze", Some(raw)) => {
                freeze = parse_freeze(raw);
                true
            }
            ("--target", Some(raw)) => {
                target = number_or(flag, raw, target);
                true
            }
            ("--base-layers", Some(raw)) => {
                base_layers = number_or(flag, raw, base_layers);
                true
            }
            ("--seed", Some(raw)) => {
                seed = number_or(flag, raw, seed);
                true
            }
            ("--cycles", Some(raw)) => {
                cycles = number_or(flag, raw, cycles);
                true
            }
            (
                "--placement" | "--freeze" | "--target" | "--base-layers" | "--seed" | "--cycles",
                None,
            ) => {
                eprintln!("warning: {} expects a value; ignoring it", flag);
                false
            }
            _ => {
                eprintln!("warning: ignoring unrecognized argument {:?}", flag);
                false
            }
        };
        idx += if took_value { 2 } else { 1 };
    }

    (placement, freeze, target, base_layers, seed, cycles)
}

/// Parses a `--placement` spec into an [`ExpansionPlacement`].
///
/// Accepted forms (surrounding whitespace is ignored, including around list entries):
/// - `append` → `ExpansionPlacement::Append`
/// - `prepend` → `ExpansionPlacement::Prepend`
/// - `insert:<index>` → `ExpansionPlacement::InsertAt { index }`, e.g. `insert:1`
/// - `specific:<i>,<j>,...` → `ExpansionPlacement::SpecificPositions`, e.g. `specific:0, 3`
///
/// Parsing is best-effort and never fails:
/// - An invalid `insert` index becomes `0`.
/// - Invalid `specific` entries are dropped; empty entries are skipped.
/// - Any other spec falls back to `Append`.
///
/// Each fallback or dropped entry is reported on stderr.
fn parse_placement(raw: &str) -> ExpansionPlacement {
    let raw = raw.trim();
    if raw == "append" {
        return ExpansionPlacement::Append;
    }
    if raw == "prepend" {
        return ExpansionPlacement::Prepend;
    }
    if let Some(value) = raw.strip_prefix("insert:") {
        let index = value.trim().parse().ok().unwrap_or_else(|| {
            eprintln!("warning: invalid insert index {:?}; using 0", value);
            0
        });
        return ExpansionPlacement::InsertAt { index };
    }
    if let Some(value) = raw.strip_prefix("specific:") {
        // Trim each entry so `specific:0, 3` keeps the `3` instead of silently dropping it.
        let positions = value
            .split(',')
            .map(str::trim)
            .filter(|item| !item.is_empty())
            .filter_map(|item| {
                let parsed = item.parse::<usize>().ok();
                if parsed.is_none() {
                    eprintln!("warning: dropping invalid position {:?}", item);
                }
                parsed
            })
            .collect::<Vec<_>>();
        return ExpansionPlacement::SpecificPositions(positions);
    }
    eprintln!("warning: unrecognized placement {:?}; using append", raw);
    ExpansionPlacement::Append
}

/// Parses a `--freeze` spec into a [`FreezeSelection`].
///
/// Accepted forms (surrounding whitespace is ignored, including around list entries):
/// - `first:<n>` → `FreezeSelection::FirstN(n)`, e.g. `first:2`
/// - `indices:<i>,<j>,...` → `FreezeSelection::Indices`, e.g. `indices:0, 1`
///
/// Parsing is best-effort and never fails:
/// - An invalid `first` count becomes `2`.
/// - Invalid `indices` entries are dropped; empty entries are skipped.
/// - Any other spec falls back to `FirstN(2)`.
///
/// Each fallback or dropped entry is reported on stderr.
fn parse_freeze(raw: &str) -> FreezeSelection {
    let raw = raw.trim();
    if let Some(value) = raw.strip_prefix("first:") {
        let count = value.trim().parse().ok().unwrap_or_else(|| {
            eprintln!("warning: invalid freeze count {:?}; using 2", value);
            2
        });
        return FreezeSelection::FirstN(count);
    }
    if let Some(value) = raw.strip_prefix("indices:") {
        // Trim each entry so `indices:0, 1` keeps the `1` instead of silently dropping it.
        let indices = value
            .split(',')
            .map(str::trim)
            .filter(|item| !item.is_empty())
            .filter_map(|item| {
                let parsed = item.parse::<usize>().ok();
                if parsed.is_none() {
                    eprintln!("warning: dropping invalid freeze index {:?}", item);
                }
                parsed
            })
            .collect::<Vec<_>>();
        return FreezeSelection::Indices(indices);
    }
    eprintln!("warning: unrecognized freeze spec {:?}; using first:2", raw);
    FreezeSelection::FirstN(2)
}

/// Runs a freeze-parity check on an [`ExperimentalTrainer`] and prints a JSON report.
///
/// 1. Builds base parameters with `TrainerParams::random`, using `--base-layers`
///    layers and `--seed`.
/// 2. Wraps them in an `ExperimentalTrainer` whose config targets `--target` layers,
///    using the requested placement and freeze selection.
/// 3. Calls `train_ff_cycle` `--cycles` times with constant dummy activations.
/// 4. Compares `layer_norms()` before and after, for the frozen layer indices only.
///
/// The report is pretty-printed JSON on stdout. It has the keys `expanded_layers`,
/// `frozen_indices`, `steps`, `frozen_unchanged`, `before_norms` and `after_norms`.
fn main() {
    let (placement, freeze, target, base_layers, seed, cycles) = parse_args();
    let spec = LayerSpec {
        d_model: 64,
        d_state: 16,
        d_conv: 4,
    };

    let base = TrainerParams::random(512, spec, base_layers, seed);
    let cfg = ExperimentalTrainerConfig {
        vocab_size: 512,
        layer_spec: spec,
        expansion: ExpansionConfig {
            target_num_layers: target,
            placement,
        },
        freeze_selection: freeze,
        freeze_embedding: true,
        ff_lr: 1e-2,
        ff_threshold: 1e-3,
        adamw: AdamWConfig::default(),
    };
    let mut trainer = ExperimentalTrainer::from_base(base, cfg);
    let before = trainer.layer_norms();
    let d = spec.d_model;
    for _ in 0..cycles {
        let n = trainer.expanded_layer_count();
        // Dummy activations: pos=+0.5, neg=-0.5 (parity test only cares about freeze)
        let h_pos: Vec<ndarray::Array1<f32>> = (0..n)
            .map(|_| ndarray::Array1::from_elem(d, 0.5_f32))
            .collect();
        let h_neg: Vec<ndarray::Array1<f32>> = (0..n)
            .map(|_| ndarray::Array1::from_elem(d, -0.5_f32))
            .collect();
        // The return value is deliberately discarded; only the layer norms are inspected afterwards.
        let _ = trainer.train_ff_cycle(&h_pos, &h_neg);
    }
    let after = trainer.layer_norms();

    let frozen = trainer.frozen_layer_indices();
    // A frozen layer counts as unchanged only if its norm exists in both snapshots
    // and moved by at most 1e-8. An out-of-range index counts as a failure rather
    // than panicking, so the report below is always printed.
    let frozen_unchanged = frozen
        .iter()
        .all(|&idx| match (before.get(idx), after.get(idx)) {
            (Some(b), Some(a)) => (*b - *a).abs() <= 1e-8,
            _ => false,
        });

    let out = json!({
        "expanded_layers": trainer.expanded_layer_count(),
        "frozen_indices": frozen,
        "steps": trainer.step,
        "frozen_unchanged": frozen_unchanged,
        "before_norms": before,
        "after_norms": after,
    });
    let report = serde_json::to_string_pretty(&out)
        .expect("serializing a serde_json::Value (string keys only) cannot fail");
    println!("{}", report);
}