rust_trainer 0.1.4

CPU-first pure-Rust supervised trainer for Selective State Space Models with Hyperspherical Prototype Networks.
Documentation
use ndarray::Array3;
use rand::rngs::StdRng;
use rand::SeedableRng;
use rust_trainer::layer::{backward, forward_with_cache};
use rust_trainer::trainer::{LayerSpec, MambaLayerParams};
use serde_json::json;

/// Runs one `forward_with_cache` / `backward` pass through a randomly
/// initialised `MambaLayerParams` and prints a JSON summary (output shapes
/// plus a few L2 norms of the resulting gradients) to stdout.
fn main() {
    /// Euclidean (L2) norm: square root of the sum of squared elements,
    /// accumulated in `f32` in the order the iterator yields them.
    fn l2_norm<'a>(values: impl IntoIterator<Item = &'a f32>) -> f32 {
        values.into_iter().map(|v| v * v).sum::<f32>().sqrt()
    }

    // Builds a (2, 6, 16) tensor whose entry at (b, t, d) is
    // `scale * (1 + b + t + d)`; used for both the input and the upstream
    // gradient so the two only differ by their scale factor.
    let ramp = |scale: f32| {
        Array3::from_shape_fn((2, 6, 16), move |(b, t, d)| scale * (1 + b + t + d) as f32)
    };

    // RNG seeded with a fixed value (13) for drawing the layer parameters.
    let mut rng = StdRng::seed_from_u64(13);
    let layer = MambaLayerParams::random(
        LayerSpec {
            d_model: 16,
            d_state: 16,
            d_conv: 4,
        },
        &mut rng,
    );

    let x = ramp(0.01);
    let dy = ramp(0.02);

    // Forward pass keeps a cache that the backward pass consumes.
    let (y, cache) = forward_with_cache(&layer, x.view());
    let (dx, grads) = backward(&layer, dy.view(), &cache);

    let report = json!({
        "y_shape": y.dim(),
        "dx_shape": dx.dim(),
        "dx_norm": l2_norm(dx.iter()),
        "grad_out_proj_norm": l2_norm(grads.out_proj_w.iter()),
        "grad_x_proj_norm": l2_norm(grads.x_proj_w.iter()),
        "grad_conv_norm": l2_norm(grads.conv1d_w.iter()),
    });
    println!("{}", serde_json::to_string_pretty(&report).unwrap());
}