# trustformers-optim
Optimization algorithms and learning rate schedulers for training transformer models.
## Current State
This crate provides comprehensive optimization infrastructure including state-of-the-art optimizers, learning rate schedulers, and distributed optimization techniques. It includes implementations of Adam, AdamW, SGD, LAMB, AdaFactor, and all three ZeRO optimization stages.
## Features

### Optimizers
- SGD: Stochastic Gradient Descent with momentum and weight decay
- Adam: Adaptive Moment Estimation optimizer
- AdamW: Adam with decoupled weight decay, the recommended choice for transformers (update rule sketched after this list)
- LAMB: Layer-wise Adaptive Moments optimizer for large batch training
- AdaFactor: Memory-efficient optimizer with adaptive learning rates
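To make the AdamW entry concrete, here is a minimal sketch of a single update step with decoupled weight decay; the function and its signature are illustrative, not this crate's API:

```rust
/// Illustrative single AdamW step over a flat f32 parameter buffer.
/// `m`/`v` are the first/second moment buffers; `t` is the 1-based step count.
fn adamw_step(
    params: &mut [f32], grads: &[f32], m: &mut [f32], v: &mut [f32],
    lr: f32, beta1: f32, beta2: f32, eps: f32, weight_decay: f32, t: i32,
) {
    let bc1 = 1.0 - beta1.powi(t); // bias correction for the first moment
    let bc2 = 1.0 - beta2.powi(t); // bias correction for the second moment
    for i in 0..params.len() {
        m[i] = beta1 * m[i] + (1.0 - beta1) * grads[i];
        v[i] = beta2 * v[i] + (1.0 - beta2) * grads[i] * grads[i];
        let update = (m[i] / bc1) / ((v[i] / bc2).sqrt() + eps);
        // Decoupled weight decay: applied to the weights directly,
        // not folded into the gradient as in classic L2 regularization.
        params[i] -= lr * (update + weight_decay * params[i]);
    }
}
```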
### Learning Rate Schedulers
- Linear: Linear warmup followed by linear decay (sketched after this list)
- Cosine: Cosine annealing with optional warmup
- Polynomial: Polynomial decay with configurable power
- Constant: Constant learning rate with optional warmup
- Exponential: Exponential decay
- Step: Step-wise learning rate reduction
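For example, a linear warmup-then-decay schedule can be computed as a pure function of the step count (an illustrative sketch, not the crate's exact API):

```rust
/// Linear schedule with warmup: ramps from 0 to `base_lr` over
/// `warmup_steps`, then decays linearly back to 0 at `total_steps`.
fn linear_lr(step: usize, warmup_steps: usize, total_steps: usize, base_lr: f32) -> f32 {
    if step < warmup_steps {
        base_lr * step as f32 / warmup_steps as f32
    } else {
        let remaining = total_steps.saturating_sub(step) as f32;
        let decay_span = (total_steps - warmup_steps) as f32;
        base_lr * (remaining / decay_span).max(0.0)
    }
}
```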
### Distributed Optimization
- ZeRO Stage 1: Optimizer state partitioning across GPUs (sharding scheme sketched after this list)
- ZeRO Stage 2: Optimizer state + gradient partitioning
- ZeRO Stage 3: Full parameter partitioning for maximum memory efficiency
- Gradient Synchronization: Efficient all-reduce operations
- Mixed Precision Support: Compatible with FP16/BF16 training
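To give a rough picture of what partitioning means in practice, the sketch below splits a flattened optimizer-state buffer as evenly as possible across ranks; the function is hypothetical, not the crate's API:

```rust
/// Assign each rank a contiguous slice of the flattened state.
/// Ranks with index < num_elems % world_size receive one extra element,
/// so shard sizes differ by at most one.
fn shard_range(num_elems: usize, world_size: usize, rank: usize) -> std::ops::Range<usize> {
    let base = num_elems / world_size;
    let extra = num_elems % world_size;
    let start = rank * base + rank.min(extra);
    let len = base + if rank < extra { 1 } else { 0 };
    start..start + len
}
```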
### Advanced Features
- Gradient Clipping: By value or by global norm (norm clipping sketched after this list)
- Weight Decay: L2 regularization and decoupled weight decay
- Momentum: Classical and Nesterov momentum
- Adaptive Learning Rates: Per-parameter learning rate adaptation
- Memory Optimization: Reduced memory footprint for large models
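Global-norm clipping, for instance, rescales all gradients by a common factor whenever their combined L2 norm exceeds a threshold (an illustrative sketch, not the crate's exact API):

```rust
/// Clip gradients in place so their global L2 norm does not exceed `max_norm`.
fn clip_grad_norm(grads: &mut [f32], max_norm: f32) -> f32 {
    let total_norm = grads.iter().map(|g| g * g).sum::<f32>().sqrt();
    if total_norm > max_norm {
        let scale = max_norm / (total_norm + 1e-6);
        for g in grads.iter_mut() {
            *g *= scale;
        }
    }
    total_norm // returned for logging, as most frameworks do
}
```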
## Usage Example

### Basic Optimizer Usage
Field names, method names, and the scheduler type below are illustrative; consult the crate docs for exact signatures.

```rust
use trustformers_optim::{AdamW, AdamWConfig, LinearScheduler, SchedulerConfig};

// Create AdamW optimizer (field names are illustrative)
let config = AdamWConfig { lr: 5e-5, weight_decay: 0.01, ..Default::default() };
let mut optimizer = AdamW::new(config)?;

// Create learning rate scheduler with warmup over the first 1,000 steps
let scheduler_config = SchedulerConfig { warmup_steps: 1_000, total_steps: num_steps, ..Default::default() };
let scheduler = LinearScheduler::new(scheduler_config);

// Training loop
for step in 0..num_steps {
    // ... forward and backward passes populate gradients ...
    optimizer.set_lr(scheduler.lr_at(step)); // hypothetical accessor
    optimizer.step()?;
    optimizer.zero_grad();
}
```
### ZeRO Optimization
Type and field names are illustrative; the crate's actual ZeRO API may differ.

```rust
use trustformers_optim::{AdamW, AdamWConfig, ZeroConfig, ZeroOptimizer};

// Configure ZeRO (here: stage 2, optimizer state + gradient partitioning)
let zero_config = ZeroConfig { stage: 2, ..Default::default() };

// Wrap optimizer with ZeRO
let base_optimizer = AdamW::new(AdamWConfig::default())?;
let optimizer = ZeroOptimizer::new(base_optimizer, zero_config)?;
```
## Architecture

```text
trustformers-optim/
├── src/
│   ├── optimizers/       # Optimizer implementations
│   │   ├── sgd.rs        # SGD optimizer
│   │   ├── adam.rs       # Adam & AdamW
│   │   ├── lamb.rs       # LAMB optimizer
│   │   └── adafactor.rs  # AdaFactor optimizer
│   ├── schedulers/       # Learning rate schedulers
│   ├── distributed/      # Distributed optimization
│   │   ├── zero.rs       # ZeRO implementation
│   │   └── utils.rs      # Communication utilities
│   └── traits.rs         # Core traits
```
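A plausible minimal shape for the core trait in `traits.rs`, shown purely as a hypothetical sketch (the actual trait may differ in names and error types):

```rust
/// Hypothetical sketch of a core optimizer trait; see src/traits.rs
/// for the crate's actual definition.
pub trait Optimizer {
    /// Apply one update using the most recently accumulated gradients.
    fn step(&mut self) -> Result<(), Box<dyn std::error::Error>>;
    /// Clear accumulated gradients before the next backward pass.
    fn zero_grad(&mut self);
    /// Read and write the learning rate, so schedulers can drive it.
    fn lr(&self) -> f32;
    fn set_lr(&mut self, lr: f32);
}
```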
## Performance

### Memory Savings with ZeRO
| Model Size | Standard | ZeRO-1 | ZeRO-2 | ZeRO-3 |
|---|---|---|---|---|
| 1.5B params | 24 GB | 16 GB | 12 GB | 8 GB |
| 7B params | 112 GB | 75 GB | 56 GB | 28 GB |
| 175B params | 2.8 TB | 1.9 TB | 1.4 TB | 700 GB |
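For scale, the Standard column corresponds to the usual mixed-precision Adam footprint of roughly 16 bytes of training state per parameter (FP16 weights and gradients plus FP32 master weights and moment estimates): 1.5 × 10⁹ params × 16 B ≈ 24 GB.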
### Optimizer Performance
- AdamW: Industry standard for transformer training
- LAMB: Enables large-batch training (batch sizes up to 64K)
- AdaFactor: Roughly 75% optimizer-state memory reduction vs. Adam
- ZeRO: Near-linear scaling across multiple GPUs
## Best Practices

### Choosing an Optimizer
- AdamW: Default choice for most transformer models
- LAMB: When using very large batch sizes
- AdaFactor: Memory-constrained environments
- SGD: Simple baseline, rarely optimal for transformers
### Learning Rate Schedules
- Linear: Standard for BERT-style pre-training
- Cosine: Often better for fine-tuning
- Constant + Warmup: Simple and effective
- Polynomial: Alternative to linear decay
### Hyperparameters

```text
// Recommended starting points
AdamW:     lr=5e-5, weight_decay=0.01, warmup=10% of steps
LAMB:      lr=2e-3, weight_decay=0.01, warmup=10% of steps
AdaFactor: lr=1e-3, no weight decay,   warmup=10% of steps
```
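Expressed with the illustrative config types from the usage example above (field names remain assumptions), the AdamW row might look like:

```rust
// Hypothetical field names; check AdamWConfig/SchedulerConfig for the real ones.
let num_steps = 100_000;
let config = AdamWConfig { lr: 5e-5, weight_decay: 0.01, ..Default::default() };
let scheduler_config = SchedulerConfig {
    warmup_steps: num_steps / 10, // warmup = 10% of total steps
    total_steps: num_steps,
    ..Default::default()
};
```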
## Testing
- Unit tests for all optimizers and schedulers
- Convergence tests on toy problems (a minimal example follows this list)
- Numerical stability tests
- Distributed operation tests
- Memory usage profiling
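A convergence test of the kind listed above can be as small as driving plain gradient descent toward the minimum of a quadratic; this sketch is self-contained and does not depend on the crate's API:

```rust
#[test]
fn gradient_descent_converges_on_quadratic() {
    // Minimize f(x) = (x - 3)^2 with plain gradient descent.
    let mut x = 0.0_f32;
    let lr = 0.1;
    for _ in 0..200 {
        let grad = 2.0 * (x - 3.0); // f'(x)
        x -= lr * grad;
    }
    assert!((x - 3.0).abs() < 1e-3, "did not converge: x = {x}");
}
```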
## License
MIT OR Apache-2.0