rust_trainer 0.1.3

CPU-first pure-Rust supervised trainer for Selective State Space Models with Hyperspherical Prototype Networks.

RUST Trainer


A CPU-first Rust training package implementing a Mamba SSM + Hyperspherical Prototype Network (HPN) architecture.

This is a concrete, working reference implementation — not a blank framework. The model is a stack of Mamba selective state-space layers with an HPN cosine-distance output head and learnable prototype matrix. Teams can use it as-is or fork and replace the layer/loss internals for their own architecture.

It works as both:

  • a library dependency in your Rust application
  • a ready-to-run CLI training binary

No Python runtime is required.


What this package gives you

Capability Status
End-to-end train loop binary
Library API for embedding custom model/training logic
Serializable optimizer state (AdamW)
Resume-safe checkpoints (model + optimizer + step)
JSONL metrics logging
Configurable layer expansion and freezing
Deterministic parity probe for save/load correctness
SIMD math kernels for high-throughput CPU training
Validation cadence + best-checkpoint tracking
Early stopping support
Gradient clipping controls
LR warmup + cosine decay controls
Non-finite update guardrails
Sharded streaming dataset support
Packed sequence batching on shard streams
Multi-worker sharded prefetch ingestion
Run-state resume for stream cursors
Atomic versioned checkpoints
Cross-framework parity harness (Rust vs Python/JAX)
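
To make the "Gradient clipping controls" and "Non-finite update guardrails" entries concrete, here is a minimal standalone sketch of global-norm clipping combined with a NaN/Inf check. It does not use the crate's API: the names GradCheck and clip_global_norm are hypothetical, and the trainer's real implementation may differ.

```rust
/// Outcome of preparing gradients for an optimizer update (hypothetical type).
#[derive(Debug, PartialEq)]
pub enum GradCheck {
    /// Gradients are usable; holds the global L2 norm measured before clipping.
    Ok(f32),
    /// A NaN or Inf was found; the caller should skip (or abort) this update.
    NonFinite,
}

/// Scale `grads` in place so their global L2 norm is at most `max_norm`.
/// A `max_norm` of 0.0 or less disables clipping, following the documented
/// "0 disables clipping" convention of `--grad-clip-norm`.
pub fn clip_global_norm(grads: &mut [Vec<f32>], max_norm: f32) -> GradCheck {
    // Accumulate in f64 to reduce rounding error across many parameters.
    let mut sq_sum = 0.0f64;
    for g in grads.iter() {
        for &v in g {
            if !v.is_finite() {
                return GradCheck::NonFinite;
            }
            sq_sum += (v as f64) * (v as f64);
        }
    }
    let norm = sq_sum.sqrt() as f32;
    if !norm.is_finite() {
        return GradCheck::NonFinite;
    }
    if max_norm > 0.0 && norm > max_norm {
        let scale = max_norm / norm;
        for g in grads.iter_mut() {
            for v in g.iter_mut() {
                *v *= scale;
            }
        }
    }
    GradCheck::Ok(norm)
}
```

A caller would typically skip the optimizer step on GradCheck::NonFinite, or panic when a flag such as --fail-on-non-finite is set.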

Production readiness status

Current state: production-candidate for single-node CPU training with production-critical ingestion and parity validation implemented.

What is already robust:

  • deterministic resume behavior (checkpoint + optimizer state + step)
  • deterministic resume behavior for streaming shard cursors (run_state.json)
  • configurable expansion/freeze controls for staged training
  • validated SIMD and backward kernels with scalar parity probes
  • CI, release, and crate packaging automation

Operational note:

  • cross-framework parity runner requires jax to be installed in the active Python environment

Detailed roadmap and release milestones are tracked in roadmap.md.


Design philosophy

  • Keep trainer internals explicit and hackable.
  • Favor reproducible runs and resumability.
  • Make the package easy to fork and specialize for custom architectures.
  • Keep data ingestion simple at first (integer token files), then scale to streaming pipelines.

Repository layout

src/
  lib.rs              - crate root and public exports
  generic_trainer.rs  - full trainer state, train step, checkpoint/resume
  trainer.rs          - parameter and expansion/freezing config types
  optim.rs            - AdamW optimizer primitives
  nn.rs               - layer norm and output-loss helpers
  simd_ops.rs         - SIMD kernels used by the model path
  layer.rs            - cached layer forward/backward helpers
  stack.rs            - stack-level supervised step helpers
src/bin/
  train_generic.rs    - main CLI trainer
  trainer_parity.rs   - deterministic parity/resume checker
  parity_lab.rs       - expansion/freeze behavior harness
  *_probe.rs          - low-level probes used for validation

Quick start

git clone https://github.com/npradeep357/rust_trainer
cd rust_trainer
cargo test

Run a short smoke training job:

cargo run --release --bin train_generic -- \
  --steps 200 \
  --batch-size 4 \
  --seq-len 32 \
  --out-dir runs/smoke

Run deterministic resume parity check:

cargo run --release --bin trainer_parity

Run Rust vs Python/JAX parity check:

cargo run --release --bin cross_framework_parity

Train your own model data

The default trainer accepts a whitespace-separated integer token file.

cargo run --release --bin train_generic -- \
  --token-file /path/to/your_tokens.txt \
  --out-dir runs/experiment_v1 \
  --steps 50000 \
  --batch-size 8 \
  --seq-len 64 \
  --d-model 512 \
  --d-state 16 \
  --base-layers 2 \
  --target-layers 6 \
  --placement specific:1,3,4,5 \
  --freeze first:2 \
  --lr 1e-4
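
For reference, a whitespace-separated integer token file can be read with a few lines of standard-library Rust. The sketch below is illustrative only: parse_tokens and infer_vocab_size are hypothetical helpers, not the crate's loader, and the "max id + 1" rule is an assumption about what an automatic vocab size could mean.

```rust
use std::num::ParseIntError;

/// Parse whitespace-separated integers (spaces, tabs, or newlines) into token ids.
/// Returns an error on the first entry that is not a valid integer.
pub fn parse_tokens(text: &str) -> Result<Vec<i64>, ParseIntError> {
    text.split_whitespace().map(|t| t.parse::<i64>()).collect()
}

/// Smallest vocabulary size that covers every token id (max id + 1).
/// Assumption: ids are non-negative; returns None for empty or negative input.
pub fn infer_vocab_size(tokens: &[i64]) -> Option<usize> {
    let max = tokens.iter().copied().max()?;
    let min = tokens.iter().copied().min()?;
    if min < 0 {
        None
    } else {
        Some(max as usize + 1)
    }
}
```

In practice the input string would come from std::fs::read_to_string on the path passed to --token-file.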

Resume training:

cargo run --release --bin train_generic -- \
  --resume runs/experiment_v1/latest.bincode \
  --out-dir runs/experiment_v1 \
  --steps 20000

CLI reference

Flag | Default | Description
--out-dir PATH | runs/ | Output directory for checkpoints and metrics
--steps N | 5000 | Number of train steps
--save-every N | 200 | Checkpoint interval
--log-every N | 20 | Metric logging interval
--batch-size N | 8 | Batch size
--seq-len N | 64 | Sequence length
--seed N | 42 | RNG seed
--base-layers N | 2 | Initial layer count before expansion
--target-layers N | 6 | Final layer count after expansion
--d-model N | 512 | Hidden width
--d-state N | 16 | State width
--d-conv N | 4 | Convolution kernel width
--placement STR | specific:1,3,4,5 | Expansion placement
--freeze STR | first:2 | Freeze policy
--lr F | 1e-4 | AdamW learning rate
--ff-lr F | 1e-4 | Forward-Forward local learning rate for d_skip updates
--bp-cadence-steps N | 32 | Apply global BP every N train steps (FF runs each step)
--gradient-surgery-method STR | pcgrad | Conflict handling method: pcgrad, gradnorm, cagradstep
--gradient-surgery-epsilon F | 1e-8 | Numerical stability epsilon for surgery operations
--gradnorm-alpha F | 0.2 | GradNorm disagreement scaling factor
--cagrad-lambda F | 1.0 | CAGradStep conflict-aversion strength
--freeze-embedding 1 | false | Freeze embedding table
--token-file PATH | none | Integer token dataset
--token-dir PATH | none | Directory of shard files for streaming training
--val-token-file PATH | none | Optional dedicated validation token dataset
--val-token-dir PATH | none | Optional validation shard directory
--shard-ext EXT | txt | Extension filter used with --token-dir / --val-token-dir
--shuffle-shards 1 | true | Shuffle shard order each epoch in streaming mode
--packed-sequences 1 | true | Use packed contiguous token windows in streaming mode
--prefetch-workers N | 0 | Number of worker threads for sharded prefetch (>1 enables multi-worker mode)
--prefetch-buffer N | 16 | Bounded channel capacity for prefetched worker batches
--resume PATH | none | Resume checkpoint
--vocab-size N | auto | Override vocab size
--val-ratio F | 0.05 | Validation split ratio when --val-token-file is not provided
--val-every N | 200 | Validation cadence in train steps
--eval-batches N | 8 | Number of validation batches per eval pass
--early-stopping-patience N | 0 | Stop when validation does not improve for N eval windows (0 disables)
--grad-clip-norm F | 0.0 | Global gradient clipping threshold (0 disables clipping)
--fail-on-non-finite 1 | false | Panic on NaN/Inf detection instead of skipping the update
--lr-warmup-steps N | 0 | Linear warmup length before decay
--lr-min-scale F | 0.1 | Minimum LR floor as fraction of base LR for cosine decay
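
As an illustration of how --lr, --lr-warmup-steps, and --lr-min-scale could interact, the sketch below computes a linear-warmup plus cosine-decay schedule. The function lr_at is hypothetical, and details such as whether warmup starts at step 0 or 1 and where the decay horizon ends are assumptions, not taken from the trainer source.

```rust
use std::f32::consts::PI;

/// Learning rate at `step` (0-based) for a run of `total_steps`.
/// Phase 1: linear warmup from base_lr / warmup_steps up to base_lr.
/// Phase 2: cosine decay from base_lr down to min_scale * base_lr.
pub fn lr_at(
    step: usize,
    total_steps: usize,
    warmup_steps: usize,
    base_lr: f32,
    min_scale: f32,
) -> f32 {
    if warmup_steps > 0 && step < warmup_steps {
        return base_lr * (step + 1) as f32 / warmup_steps as f32;
    }
    // Number of steps over which the cosine decay is spread (at least 1).
    let decay_steps = total_steps.saturating_sub(warmup_steps).max(1);
    let done = step.saturating_sub(warmup_steps);
    // Progress through the decay phase, clamped to [0, 1].
    let progress = (done as f32 / decay_steps as f32).min(1.0);
    let cosine = 0.5 * (1.0 + (PI * progress).cos());
    base_lr * (min_scale + (1.0 - min_scale) * cosine)
}
```

With min_scale set to 0.1, the rate never drops below 10% of the base learning rate, which matches the "minimum LR floor" wording in the table.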

Debug + recovery artifacts

  • latest.bincode: latest atomic, versioned checkpoint (model + optimizer + step)
  • best.bincode: best validation checkpoint
  • run_state.json: resumable data-pipeline state (in-memory cursor or shard stream cursor)
  • metrics.jsonl: train/validation metrics stream for dashboards and debugging
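
A common way to make a checkpoint file "atomic" is to write to a temporary file and then rename it over the target. The sketch below shows that pattern using only the standard library; write_atomic is a hypothetical helper, and this README does not state that the crate uses exactly this mechanism.

```rust
use std::fs::{self, File};
use std::io::{self, Write};
use std::path::Path;

/// Write `bytes` to `path` so that readers see either the old file or the
/// complete new file, never a partially written one.
pub fn write_atomic(path: &Path, bytes: &[u8]) -> io::Result<()> {
    // Temporary sibling file, e.g. latest.bincode -> latest.tmp.
    // Staying in the same directory keeps the rename on one filesystem.
    let tmp = path.with_extension("tmp");
    {
        let mut f = File::create(&tmp)?;
        f.write_all(bytes)?;
        // Flush file contents to disk before the rename makes them visible.
        f.sync_all()?;
    }
    // Replaces any existing file at `path`.
    fs::rename(&tmp, path)
}
```

If the process crashes before the rename, the previous latest.bincode is left untouched, which is what makes resuming from it safe.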

Placement values

Value | Meaning
append | Add new layers at the end
prepend | Add new layers at the beginning
insert:N | Insert all new layers starting at index N
specific:1,3,4,5 | Place each new layer at specific final indices
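
The placement values above can be read as a mapping from (base layers, target layers) to the final indices occupied by new layers. The sketch below spells out one plausible interpretation. Placement, parse_placement, and new_layer_indices are hypothetical names for illustration (the crate exposes its own ExpansionPlacement type), and the validation rules are assumptions.

```rust
/// Hypothetical mirror of the documented --placement values.
#[derive(Debug, Clone, PartialEq)]
pub enum Placement {
    Append,
    Prepend,
    Insert(usize),
    Specific(Vec<usize>),
}

/// Parse strings such as "append", "insert:2", or "specific:1,3,4,5".
pub fn parse_placement(s: &str) -> Option<Placement> {
    if s == "append" {
        Some(Placement::Append)
    } else if s == "prepend" {
        Some(Placement::Prepend)
    } else if let Some(n) = s.strip_prefix("insert:") {
        n.parse::<usize>().ok().map(Placement::Insert)
    } else if let Some(list) = s.strip_prefix("specific:") {
        list.split(',')
            .map(|p| p.trim().parse::<usize>().ok())
            .collect::<Option<Vec<usize>>>()
            .map(Placement::Specific)
    } else {
        None
    }
}

/// Final 0-based indices (in the expanded stack) that hold new layers.
/// Returns None when the placement is inconsistent with the layer counts.
pub fn new_layer_indices(p: &Placement, base: usize, target: usize) -> Option<Vec<usize>> {
    let added = target.checked_sub(base)?;
    let idx: Vec<usize> = match p {
        Placement::Append => (base..target).collect(),
        Placement::Prepend => (0..added).collect(),
        Placement::Insert(n) => {
            // The insertion point must lie within the original stack.
            if *n > base {
                return None;
            }
            (*n..*n + added).collect()
        }
        Placement::Specific(v) => {
            let mut v = v.clone();
            v.sort_unstable();
            v.dedup();
            // Exactly one distinct, in-range index per new layer.
            if v.len() != added || v.iter().any(|&i| i >= target) {
                return None;
            }
            v
        }
    };
    Some(idx)
}
```

Under this reading, the default specific:1,3,4,5 with 2 base layers and 6 target layers leaves the original layers at final indices 0 and 2.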

Freeze values

Value | Meaning
first:N | Freeze first N layers
indices:0,2,5 | Freeze explicit layer indices
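
A freeze policy can be thought of as a per-layer boolean mask. The following sketch turns the two documented forms into such a mask. freeze_mask is a hypothetical helper, and it assumes indices are 0-based positions in the final (expanded) stack, which this README does not state explicitly.

```rust
/// Build a mask where `true` means "this layer's parameters are frozen".
/// Accepts "first:N" or "indices:a,b,c"; returns None on malformed or
/// out-of-range input.
pub fn freeze_mask(spec: &str, n_layers: usize) -> Option<Vec<bool>> {
    let mut mask = vec![false; n_layers];
    if let Some(n) = spec.strip_prefix("first:") {
        let n: usize = n.trim().parse().ok()?;
        if n > n_layers {
            return None;
        }
        for m in mask.iter_mut().take(n) {
            *m = true;
        }
    } else if let Some(list) = spec.strip_prefix("indices:") {
        for part in list.split(',') {
            let i: usize = part.trim().parse().ok()?;
            // get_mut returns None for an out-of-range index.
            *mask.get_mut(i)? = true;
        }
    } else {
        return None;
    }
    Some(mask)
}
```

An optimizer loop would then skip the parameter update for every layer whose mask entry is true.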

Use as a library

Add dependency:

[dependencies]
rust_trainer = "0.1"

If you forked and renamed the crate, use the package name from your own Cargo.toml instead.

Minimal integration example:

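use rust_trainer::generic_trainer::{
    GenericTrainer, default_trainer_config, make_batch_from_tokens,
};
use rust_trainer::{ExpansionPlacement, FreezeSelection, LayerSpec};

fn main() {
    // Layer dimensions; the values match the CLI defaults.
    let spec = LayerSpec { d_model: 512, d_state: 16, d_conv: 4 };

    // The positional arguments appear to mirror the CLI flags (vocab size,
    // layer spec, target layers, placement, freeze policy, freeze-embedding,
    // learning rate); check the crate docs for the exact signature.
    let cfg = default_trainer_config(
        8192,
        spec,
        6,
        ExpansionPlacement::SpecificPositions(vec![1, 3, 4, 5]),
        FreezeSelection::FirstN(2),
        false,
        1e-4,
    );

    // Random initialization; 2 and 42 match the --base-layers and --seed defaults.
    let mut trainer = GenericTrainer::new_random(cfg, 2, 42);

    // Synthetic token stream, used here only to exercise one train step.
    let tokens: Vec<i64> = (0..8192).collect();
    let (ids, targets) = make_batch_from_tokens(&tokens, 0, 8, 64);
    let stats = trainer.train_step(&ids, &targets);
    println!("loss: {}", stats.loss);
    trainer.save_checkpoint("checkpoint.bincode").unwrap();
}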

Architecture

The default model uses:

Component | Implementation
Sequence layers | Mamba SSM (causal conv1d + SiLU + discretized state scan)
Output head | Hyperspherical Prototype Network (HPN)
Loss | Squared cosine distance to nearest prototype
Optimizer | AdamW with serializable moment buffers
Inference path | CPU-only; no GPU required
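
To illustrate the loss row, the sketch below computes a squared cosine distance, (1 - cos(z, p))^2, between an output vector z and a prototype p, plus a nearest-prototype lookup by cosine similarity. The function names are hypothetical, and the crate's actual head in src/nn.rs may normalize, batch, or select prototypes differently.

```rust
/// Cosine similarity between two equal-length vectors.
pub fn cosine(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len(), "vectors must have the same length");
    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    let na = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let nb = b.iter().map(|x| x * x).sum::<f32>().sqrt();
    // Guard against division by zero for all-zero vectors.
    dot / (na * nb).max(1e-12)
}

/// Squared cosine distance: 0 when aligned, 1 when orthogonal, 4 when opposite.
pub fn squared_cosine_distance(z: &[f32], prototype: &[f32]) -> f32 {
    let d = 1.0 - cosine(z, prototype);
    d * d
}

/// Index of the prototype with the highest cosine similarity to `z`.
pub fn nearest_prototype(z: &[f32], prototypes: &[Vec<f32>]) -> Option<usize> {
    let mut best: Option<(usize, f32)> = None;
    for (i, p) in prototypes.iter().enumerate() {
        let c = cosine(z, p);
        if best.map_or(true, |(_, b)| c > b) {
            best = Some((i, c));
        }
    }
    best.map(|(i, _)| i)
}
```

Because only the angle between vectors matters, the loss is invariant to the scale of z, which is the defining property of a hyperspherical output space.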

Customize for your own architecture

The package is designed to be forked for other architectures. Replace or extend:

  1. Layer forward/backward path in src/layer.rs — swap Mamba for Transformer, LSTM, etc.
  2. Output loss/head logic in src/nn.rs — swap HPN for cross-entropy, contrastive loss, etc.
  3. Trainer state wiring in src/generic_trainer.rs — add or remove parameter groups
  4. Data loading logic in src/bin/train_generic.rs

The checkpointing, optimizer state, logging, expansion, and freeze infrastructure are all architecture-independent and can be kept as-is.


Release flow

Releases are tag-driven via GitHub Actions.

# bump version in Cargo.toml, commit, then:
git tag v0.2.0
git push origin v0.2.0

The release workflow runs tests, builds binaries, creates a GitHub Release, and can publish to crates.io when credentials are configured.


License

Apache-2.0. See LICENSE.