# axonml-optim

## Overview
axonml-optim provides optimization algorithms for training neural networks in the AxonML framework: five optimizers (SGD, Adam, AdamW, RMSprop, LAMB), seven learning-rate schedulers, a dynamic GradScaler for mixed-precision training, and a training health monitor that watches the run for NaNs, exploding/vanishing gradients, and stalled convergence.
## Features
- SGD - Stochastic Gradient Descent with optional momentum, Nesterov acceleration, weight decay, and dampening.
- Adam - Adaptive Moment Estimation with bias correction and optional AMSGrad variant.
- AdamW - Adam with decoupled weight decay regularization for improved generalization.
- RMSprop - Root Mean Square Propagation with optional momentum and centered gradient normalization.
- LAMB - Layer-wise Adaptive Moments for large-batch training (batch sizes of 32k+); Adam plus a per-layer trust ratio.
- Learning Rate Schedulers - `StepLR`, `MultiStepLR`, `ExponentialLR`, `CosineAnnealingLR`, `OneCycleLR`, `WarmupLR`, and `ReduceLROnPlateau`.
- GradScaler - Dynamic loss scaling for AMP; doubles on a healthy growth interval, halves on inf/NaN. Pairs with `autocast`/`AutocastGuard` from `axonml-autograd::amp`.
- Builder Pattern - Fluent API (`Adam::new(...).betas(...).eps(...).weight_decay(...).amsgrad(true)`) for configuring optimizer hyperparameters.
- Unified Interface - Common `Optimizer` trait (`step`, `zero_grad`, `get_lr`, `set_lr`) for interoperability; a sketch of the trait follows the Modules table below.
- Fused Optimizer Loops - Adam, SGD, and RMSprop apply momentum, weight decay, and parameter updates in a single pass per tensor, reducing memory traffic (see the sketch after this list).
- GPU-Resident State - Optimizer state (e.g. LAMB's `exp_avg`/`exp_avg_sq`) is allocated on the same device as the parameter; no CPU round-trips.
- Training Health Monitor - `TrainingMonitor` records per-step loss, gradient norm, and LR; emits `TrainingAlert`s (NaN, exploding/vanishing grad, stalled loss) with `AlertSeverity`; exports a `HealthReport` including `LossTrend` and a convergence score with suggested LR.
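To make the fused-update idea concrete, here is a minimal sketch of a single-pass SGD-with-momentum step over one tensor (plain slices and illustrative names, not the crate's internals):

```rust
/// Illustrative fused update: weight decay, momentum, and the parameter
/// write happen in one traversal of the tensor instead of three loops.
fn fused_sgd_step(
    w: &mut [f32], grad: &[f32], vel: &mut [f32],
    lr: f32, momentum: f32, weight_decay: f32,
) {
    for i in 0..w.len() {
        let g = grad[i] + weight_decay * w[i]; // decay folded into the gradient
        vel[i] = momentum * vel[i] + g;        // momentum buffer update
        w[i] -= lr * vel[i];                   // parameter update, same pass
    }
}
```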
## Modules

| Module | Description |
|---|---|
| `optimizer` | Core `Optimizer` trait |
| `sgd` | SGD with momentum, Nesterov, dampening, weight decay |
| `adam` | Adam and AdamW |
| `rmsprop` | RMSprop with optional centering and momentum |
| `lamb` | LAMB layer-wise adaptive moments |
| `lr_scheduler` | `LRScheduler` trait + seven concrete schedulers |
| `grad_scaler` | `GradScaler` / `GradScalerState` for AMP loss scaling |
| `health` | `TrainingMonitor`, `MonitorConfig`, `HealthReport`, `TrainingAlert`, `AlertKind`, `AlertSeverity`, `LossTrend` |
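Going by the method names listed in the features above, the unified trait has roughly this shape (a sketch; the exact signatures are assumptions, not the crate's published API):

```rust
/// Sketch of the unified `Optimizer` trait, reconstructed from the method
/// names above; actual signatures in the crate may differ.
pub trait Optimizer {
    /// Apply one update step to all tracked parameters.
    fn step(&mut self);
    /// Reset all parameter gradients to zero.
    fn zero_grad(&mut self);
    /// Current learning rate (schedulers read this).
    fn get_lr(&self) -> f32;
    /// Override the learning rate (schedulers write this).
    fn set_lr(&mut self, lr: f32);
}
```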
## Cargo Features

| Feature | Purpose |
|---|---|
| `cuda` | Forwards CUDA support to `axonml-core` / `axonml-tensor` so optimizer state stays GPU-resident |
## Usage

Add to your `Cargo.toml`:

```toml
[dependencies]
axonml-optim = "0.6.1"
```
### Basic Training Loop

```rust
// Crate paths and nn types below are illustrative; only the optimizer
// API comes from axonml-optim.
use axonml_optim::*;
use axonml_nn::{Linear, MSELoss, Sequential};
use axonml_autograd::Variable;
use axonml_tensor::Tensor;

// Create model (layer types and sizes are illustrative)
let model = Sequential::new()
    .add(Linear::new(784, 128))
    .add(Linear::new(128, 10));

// Create optimizer
let mut optimizer = Adam::new(model.parameters(), 0.001);
let loss_fn = MSELoss::new();

// Training loop (`input`/`target` are your batch tensors)
for epoch in 0..100 {
    let output = model.forward(&input);
    let loss = loss_fn.forward(&output, &target);

    optimizer.zero_grad();
    loss.backward();
    optimizer.step();
}
```
### SGD with Momentum

```rust
use axonml_optim::SGD;

// Basic SGD (learning rate is illustrative)
let mut optimizer = SGD::new(model.parameters(), 0.01);

// SGD with momentum (hyperparameter values are illustrative)
let mut optimizer = SGD::new(model.parameters(), 0.01)
    .momentum(0.9)
    .weight_decay(1e-4)
    .nesterov(true);
```
### Adam with Custom Configuration

```rust
use axonml_optim::{Adam, AdamW};

// Adam with custom betas (values are illustrative)
let mut optimizer = Adam::new(model.parameters(), 0.001)
    .betas(0.9, 0.999)
    .eps(1e-8)
    .weight_decay(0.01)
    .amsgrad(true);

// AdamW for decoupled weight decay
let mut optimizer = AdamW::new(model.parameters(), 0.001)
    .weight_decay(0.01);
```
### LAMB for Large-Batch Training

```rust
use axonml_optim::LAMB;

// Learning rate is illustrative
let mut optimizer = LAMB::new(model.parameters(), 0.001);

// LAMB scales each parameter's update by a per-layer trust ratio,
// enabling stable training at batch sizes of 32k+.
```
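For intuition, the trust ratio is the layer's weight norm over its update norm; a minimal sketch under that reading (plain slices, not the crate's API):

```rust
/// Sketch of LAMB's per-layer trust ratio: scale the Adam-style update so
/// its magnitude tracks the layer's weight norm (guarding the zero cases).
fn trust_ratio(weights: &[f32], update: &[f32]) -> f32 {
    let norm = |v: &[f32]| v.iter().map(|x| x * x).sum::<f32>().sqrt();
    let (w_norm, u_norm) = (norm(weights), norm(update));
    if w_norm > 0.0 && u_norm > 0.0 { w_norm / u_norm } else { 1.0 }
}
// Each layer then steps by: w -= lr * trust_ratio * adam_update
```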
### Learning Rate Scheduling

```rust
use axonml_optim::{CosineAnnealingLR, OneCycleLR, SGD, StepLR};

let mut optimizer = SGD::new(model.parameters(), 0.1);

// Step decay every 10 epochs (constructor arguments are illustrative)
let mut scheduler = StepLR::new(10, 0.1);

// Cosine annealing over 100 epochs
let mut scheduler = CosineAnnealingLR::new(100, 0.0);

// One-cycle policy for super-convergence
let mut scheduler = OneCycleLR::new(1.0, 100);

// In training loop
for epoch in 0..epochs {
    // ... forward/backward/optimizer.step() ...
    scheduler.step(&mut optimizer); // call shape is illustrative
}
```
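For reference, the first two schedules reduce to one-line formulas; a sketch of the standard decay rules (illustrative, not the crate's code):

```rust
/// StepLR: multiply the base LR by `gamma` every `step_size` epochs.
fn step_lr(base_lr: f32, gamma: f32, step_size: u32, epoch: u32) -> f32 {
    base_lr * gamma.powi((epoch / step_size) as i32)
}

/// CosineAnnealingLR: anneal from `base_lr` down to `min_lr` over `t_max` epochs.
fn cosine_lr(base_lr: f32, min_lr: f32, t_max: u32, epoch: u32) -> f32 {
    let t = epoch as f32 / t_max as f32;
    min_lr + 0.5 * (base_lr - min_lr) * (1.0 + (std::f32::consts::PI * t).cos())
}
```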
### ReduceLROnPlateau

```rust
use axonml_optim::{ReduceLROnPlateau, SGD};

let mut optimizer = SGD::new(model.parameters(), 0.1);
// factor/patience arguments are illustrative
let mut scheduler = ReduceLROnPlateau::with_options(0.1, 10);

// Step with validation loss
scheduler.step_with_metric(&mut optimizer, val_loss);
```
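The plateau rule itself is simple: cut the LR by a factor once the metric has failed to improve for `patience` consecutive checks. A standalone sketch (illustrative, not the crate's implementation):

```rust
/// Illustrative plateau rule: reduce the LR by `factor` after `patience`
/// consecutive checks without improvement in the monitored metric.
struct Plateau { best: f32, bad_steps: u32, patience: u32, factor: f32 }

impl Plateau {
    fn step(&mut self, metric: f32, lr: &mut f32) {
        if metric < self.best {
            self.best = metric;     // improvement: reset the counter
            self.bad_steps = 0;
        } else {
            self.bad_steps += 1;
            if self.bad_steps >= self.patience {
                *lr *= self.factor; // plateau: reduce the learning rate
                self.bad_steps = 0;
            }
        }
    }
}
```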
### Mixed Precision (AMP) with GradScaler

```rust
use axonml_optim::{Adam, GradScaler};
use axonml_autograd::amp::autocast;
use axonml_tensor::DType;

let mut optimizer = Adam::new(model.parameters(), 0.001);
let mut scaler = GradScaler::new();

// `batches` is your data iterator; call shapes below are illustrative
for batch in batches {
    let loss = autocast(DType::F16, || {
        loss_fn.forward(&model.forward(&batch.input), &batch.target)
    });
    optimizer.zero_grad();
    scaler.scale(&loss).backward(); // scale the loss before backprop
    scaler.step(&mut optimizer);    // unscale grads; skip step on inf/NaN
    scaler.update();                // grow or shrink the scale factor
}
```
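The dynamic-scaling rule from the feature list (double on a healthy growth interval, halve on inf/NaN) can be sketched standalone (illustrative, not the crate's internals):

```rust
/// Minimal sketch of dynamic loss scaling: halve on overflow, double after
/// `growth_interval` consecutive clean steps.
struct Scaler { scale: f32, good_steps: u32, growth_interval: u32 }

impl Scaler {
    fn update(&mut self, found_inf_or_nan: bool) {
        if found_inf_or_nan {
            self.scale *= 0.5;      // back off immediately on overflow
            self.good_steps = 0;
        } else {
            self.good_steps += 1;
            if self.good_steps >= self.growth_interval {
                self.scale *= 2.0;  // probe a higher scale
                self.good_steps = 0;
            }
        }
    }
}
```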
### Training Health Monitor

The optimizer stack monitors its own training health, detecting problems before they ruin the run.

```rust
use axonml_optim::health::{MonitorConfig, TrainingMonitor};

let mut monitor = TrainingMonitor::new(MonitorConfig::default());

// Record metrics each training step (`loss`/`grad_norm`/`lr` are your values)
for step in 0..1000 {
    monitor.record(step, loss, grad_norm, lr); // call shape is illustrative
}

// Analyze training health (field names follow the HealthReport description)
let report = monitor.health_report();
println!("Loss trend: {:?}", report.loss_trend);
println!("Convergence score: {:.2}", report.convergence_score);
println!("Suggested LR: {}", report.suggested_lr);
println!("Alerts: {}", report.alerts.len());
```
## Tests
Run the test suite:
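```bash
cargo test
```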
## License
Licensed under either of:
- Apache License, Version 2.0 (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0)
- MIT License (LICENSE-MIT or http://opensource.org/licenses/MIT)
at your option.
Last updated: 2026-04-16 (v0.6.1)