# Rust-LSTM
A comprehensive LSTM (Long Short-Term Memory) neural network library implemented in Rust with complete training capabilities, multiple optimizers, and advanced regularization.
## Network Architecture Overview

```mermaid
graph TD
    A["Input Sequence<br/>(x₁, x₂, ..., xₜ)"] --> B["LSTM Layer 1"]
    B --> C["LSTM Layer 2"]
    C --> D["Output Layer"]
    D --> E["Predictions<br/>(y₁, y₂, ..., yₜ)"]
    F["Hidden State h₀"] --> B
    G["Cell State c₀"] --> B
    B --> H["Hidden State h₁"]
    B --> I["Cell State c₁"]
    H --> C
    I --> C
    style A fill:#e1f5fe
    style E fill:#e8f5e8
    style B fill:#fff3e0
    style C fill:#fff3e0
```
## Features
- LSTM, BiLSTM & GRU Networks with multi-layer support
- Complete Training System with backpropagation through time (BPTT)
- Multiple Optimizers: SGD, Adam, RMSprop with comprehensive learning rate scheduling
- Advanced Learning Rate Scheduling: 12 different schedulers including OneCycle, Warmup, Cyclical, and Polynomial
- Early Stopping: Prevent overfitting with configurable patience and metric monitoring
- Loss Functions: MSE, MAE, Cross-entropy with softmax
- Advanced Dropout: Input, recurrent, output dropout, variational dropout, and zoneout
- Batch Processing: 4-5x training speedup with efficient batch operations
- Schedule Visualization: ASCII visualization of learning rate schedules
- Model Persistence: Save/load models in JSON or binary format
- Peephole LSTM variant with direct cell-state connections to the gates
## Quick Start

Add the dependency to your `Cargo.toml` (crate name assumed to match the project name):

```toml
[dependencies]
rust-lstm = "0.5.0"
```
### Basic Usage

```rust
use ndarray::Array2;
use rust_lstm::LSTMNetwork; // crate path assumed
```
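Building on those imports, here is a minimal forward-pass sketch. The constructor signature and dimensions are illustrative assumptions; `forward_sequence` is the method name used in the Bidirectional LSTM example below.

```rust
// Hypothetical dimensions: 10 input features, 20 hidden units, 2 stacked layers
let mut network = LSTMNetwork::new(10, 20, 2);

// A toy sequence of three timesteps, each a (features x 1) column vector
let sequence: Vec<Array2<f64>> = (0..3).map(|_| Array2::zeros((10, 1))).collect();

// Forward pass over the whole sequence
let outputs = network.forward_sequence(&sequence);
println!("produced {} output steps", outputs.len());
```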
### Training Example
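A typical training run looks like the sketch below. The `LSTMTrainer`, `Adam`, and `MSELoss` names are assumptions based on the feature list above, not confirmed API.

```rust
use rust_lstm::{Adam, LSTMNetwork, LSTMTrainer, MSELoss}; // names and paths assumed

// Network, optimizer, and loss (constructor arguments are illustrative)
let network = LSTMNetwork::new(10, 20, 2);
let optimizer = Adam::new(0.001);
let loss = MSELoss::new();
let mut trainer = LSTMTrainer::new(network, optimizer, loss);

// `train_data` would be a Vec of (input sequence, target sequence) pairs;
// the method name below is also an assumption.
// trainer.train(&train_data, /* epochs */ 100);
```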
### Early Stopping
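Early stopping monitors a validation metric and halts training once it stops improving. The sketch below illustrates the idea; the `EarlyStopping` type and its constructor are assumed names, while configurable patience and metric monitoring are the features listed above.

```rust
use rust_lstm::EarlyStopping; // type name and path assumed

// Stop when validation loss fails to improve by at least `min_delta`
// for `patience` consecutive epochs (argument names are illustrative).
let mut early_stopping = EarlyStopping::new(/* patience */ 10, /* min_delta */ 1e-4);

// Inside the training loop, after computing `val_loss` for the epoch:
// if early_stopping.should_stop(val_loss) { break; }
```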
### Bidirectional LSTM

```rust
use rust_lstm::BidirectionalLSTM; // type name and path assumed

// BiLSTM with concatenated outputs (output_size = 2 * hidden_size);
// constructor arguments are illustrative
let mut bilstm = BidirectionalLSTM::new_concat(10, 20, 1);

// Process a sequence (Vec<Array2<f64>>) with both past and future context
let outputs = bilstm.forward_sequence(&sequence);
```
### BiLSTM Architecture

```mermaid
graph TD
    A["Input Sequence<br/>(x₁, x₂, x₃, x₄)"] --> B["Forward LSTM"]
    A --> C["Backward LSTM"]
    B --> D["Forward Hidden States<br/>(h₁→, h₂→, h₃→, h₄→)"]
    C --> E["Backward Hidden States<br/>(h₁←, h₂←, h₃←, h₄←)"]
    D --> F["Combine Layer<br/>(Concat/Sum/Average)"]
    E --> F
    F --> G["BiLSTM Output<br/>(combined representations)"]
    style A fill:#e1f5fe
    style B fill:#fff3e0
    style C fill:#fff3e0
    style F fill:#f3e5f5
    style G fill:#e8f5e8
```
### GRU Networks

```rust
use ndarray::Array2;
use rust_lstm::GRUNetwork; // path assumed

// Create a GRU network (alternative to LSTM); sizes and dropout rates are illustrative
let mut gru = GRUNetwork::new(10, 20, 2)
    .with_input_dropout(0.2)
    .with_recurrent_dropout(0.3);

// Forward pass on one input (argument and return types assumed)
let x = Array2::<f64>::zeros((10, 1));
let output = gru.forward(&x);
```
### LSTM vs GRU Cell Comparison

```mermaid
graph LR
    subgraph "LSTM Cell"
        A1["Input xₜ"] --> B1["Forget Gate<br/>fₜ = σ(Wf·[hₜ₋₁,xₜ] + bf)"]
        A1 --> C1["Input Gate<br/>iₜ = σ(Wi·[hₜ₋₁,xₜ] + bi)"]
        A1 --> D1["Candidate Values<br/>C̃ₜ = tanh(WC·[hₜ₋₁,xₜ] + bC)"]
        A1 --> E1["Output Gate<br/>oₜ = σ(Wo·[hₜ₋₁,xₜ] + bo)"]
        B1 --> F1["Cell State<br/>Cₜ = fₜ * Cₜ₋₁ + iₜ * C̃ₜ"]
        C1 --> F1
        D1 --> F1
        F1 --> G1["Hidden State<br/>hₜ = oₜ * tanh(Cₜ)"]
        E1 --> G1
    end
    subgraph "GRU Cell"
        A2["Input xₜ"] --> B2["Reset Gate<br/>rₜ = σ(Wr·[hₜ₋₁,xₜ])"]
        A2 --> C2["Update Gate<br/>zₜ = σ(Wz·[hₜ₋₁,xₜ])"]
        A2 --> D2["Candidate State<br/>h̃ₜ = tanh(W·[rₜ*hₜ₋₁,xₜ])"]
        B2 --> D2
        C2 --> E2["Hidden State<br/>hₜ = (1-zₜ)*hₜ₋₁ + zₜ*h̃ₜ"]
        D2 --> E2
    end
    style B1 fill:#ffcdd2
    style C1 fill:#c8e6c9
    style D1 fill:#fff3e0
    style E1 fill:#e1f5fe
    style B2 fill:#ffcdd2
    style C2 fill:#c8e6c9
    style D2 fill:#fff3e0
```
## Advanced Learning Rate Scheduling

The library includes 12 different learning rate schedulers with visualization capabilities. The snippet below shows the general usage pattern; the helper function and scheduler names come from the crate, while module paths, argument lists, and values are illustrative assumptions:

```rust
use rust_lstm::*; // exact module paths assumed

// Create a network (sizes are illustrative)
let network = LSTMNetwork::new(10, 20, 2);

// Each trainer below is shown in isolation; pick the one you need.

// Step decay: reduce LR by 50% every 10 epochs
let mut trainer = create_step_lr_trainer(network, /* lr */ 0.01, /* step_size */ 10, /* gamma */ 0.5);

// OneCycle policy for modern deep learning
let mut trainer = create_one_cycle_trainer(network, /* max_lr */ 0.01, /* total_epochs */ 100);

// Cosine annealing with warm restarts
let mut trainer = create_cosine_annealing_trainer(network, /* lr */ 0.01, /* t_max */ 10);

// Advanced combinations - warmup wrapped around cyclical scheduling
let base_scheduler = CyclicalLR::new(/* min_lr */ 0.001, /* max_lr */ 0.01, /* cycle_len */ 20);
let warmup_scheduler = WarmupScheduler::new(base_scheduler, /* warmup_epochs */ 5);
let optimizer = Adam::new(/* lr */ 0.001); // attaching the scheduler depends on the crate's API

// Polynomial decay with visualization
let poly_scheduler = PolynomialLR::new(/* initial_lr */ 0.1, /* total_epochs */ 100, /* power */ 2.0);
print_schedule(&poly_scheduler, /* epochs */ 100);
```
Available Schedulers:
- ConstantLR: No scheduling (baseline)
- StepLR: Step decay at regular intervals
- MultiStepLR: Multi-step decay at specific milestones
- ExponentialLR: Exponential decay each epoch
- CosineAnnealingLR: Smooth cosine oscillation
- CosineAnnealingWarmRestarts: Cosine with periodic restarts
- OneCycleLR: One cycle policy for super-convergence
- ReduceLROnPlateau: Adaptive reduction on validation plateaus
- LinearLR: Linear interpolation between rates
- PolynomialLR ✨: Polynomial decay with configurable power
- CyclicalLR ✨: Triangular, triangular2, and exponential range modes
- WarmupScheduler ✨: Gradual warmup wrapper for any base scheduler
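To make the decay shapes concrete, the self-contained sketch below (independent of the crate's API) computes step, exponential, and polynomial schedules per epoch; the formulas follow the common definitions, and the crate's exact parameters may differ.

```rust
fn main() {
    let base_lr = 0.1;
    let total_epochs = 30;

    for epoch in 0..total_epochs {
        // StepLR: halve the learning rate every 10 epochs
        let step_lr = base_lr * 0.5f64.powi(epoch / 10);

        // ExponentialLR: multiply by a constant gamma (0.95) each epoch
        let exp_lr = base_lr * 0.95f64.powi(epoch);

        // PolynomialLR: decay toward zero with a configurable power (here power = 2)
        let remaining = 1.0 - epoch as f64 / total_epochs as f64;
        let poly_lr = base_lr * remaining.powi(2);

        println!("epoch {epoch:2}: step={step_lr:.4} exp={exp_lr:.4} poly={poly_lr:.4}");
    }
}
```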
## Architecture

- `layers`: LSTM and GRU cells (standard, peephole, bidirectional) with dropout
- `models`: High-level network architectures (LSTM, BiLSTM, GRU)
- `training`: Training utilities with automatic train/eval mode switching
- `optimizers`: SGD, Adam, RMSprop with scheduling
- `loss`: MSE, MAE, Cross-entropy loss functions
- `schedulers`: Learning rate scheduling algorithms
## Examples

Run the bundled examples with `cargo run --example <name>`; they cover:

- Basic usage and training
- Advanced architectures
- Learning rate scheduling
- Performance and batch processing
- Real-world applications
- Analysis and debugging
## Advanced Features
### Dropout Types
- Input Dropout: Applied to inputs before computing gates
- Recurrent Dropout: Applied to hidden states with variational support
- Output Dropout: Applied to layer outputs
- Zoneout: RNN-specific regularization preserving previous states
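The first two dropout types are configured through the builder methods shown in the GRU example above; the sketch below extends that pattern. The output-dropout and zoneout method names are hypothetical, included only to show where they would fit.

```rust
use rust_lstm::LSTMNetwork; // path assumed

// Rates are illustrative. Only `with_input_dropout` and `with_recurrent_dropout`
// appear elsewhere in this README; the commented calls are hypothetical.
let mut network = LSTMNetwork::new(10, 20, 2)
    .with_input_dropout(0.2)       // drop 20% of inputs before the gate computations
    .with_recurrent_dropout(0.3);  // variational dropout on hidden-to-hidden connections
    // .with_output_dropout(0.2)   // hypothetical: dropout on layer outputs
    // .with_zoneout(0.1)          // hypothetical: randomly keep the previous hidden state
```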
### Optimizers
- SGD: Stochastic gradient descent with momentum
- Adam: Adaptive moment estimation with bias correction
- RMSprop: Root mean square propagation
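For reference, Adam's update with bias correction works as in the standalone sketch below (the standard algorithm, not the crate's internal implementation):

```rust
// One Adam step for a single parameter, following Kingma & Ba (2015).
fn adam_step(
    param: &mut f64,
    grad: f64,
    m: &mut f64,  // running (biased) first-moment estimate
    v: &mut f64,  // running (biased) second-moment estimate
    t: i32,       // timestep, starting at 1
    lr: f64,
    beta1: f64,
    beta2: f64,
    eps: f64,
) {
    *m = beta1 * *m + (1.0 - beta1) * grad;
    *v = beta2 * *v + (1.0 - beta2) * grad * grad;
    let m_hat = *m / (1.0 - beta1.powi(t)); // bias-corrected first moment
    let v_hat = *v / (1.0 - beta2.powi(t)); // bias-corrected second moment
    *param -= lr * m_hat / (v_hat.sqrt() + eps);
}

fn main() {
    let (mut w, mut m, mut v) = (1.0, 0.0, 0.0);
    adam_step(&mut w, 0.5, &mut m, &mut v, 1, 0.001, 0.9, 0.999, 1e-8);
    println!("parameter after one step: {w}");
}
```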
### Loss Functions
- MSELoss: Mean squared error for regression
- MAELoss: Mean absolute error for robust regression
- CrossEntropyLoss: Numerically stable softmax cross-entropy for classification
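Numerical stability in the cross-entropy comes from the standard log-sum-exp trick: subtract the maximum logit before exponentiating. A standalone sketch (not the crate's implementation):

```rust
// Softmax cross-entropy for a single sample, stabilized with log-sum-exp.
fn cross_entropy(logits: &[f64], target: usize) -> f64 {
    let max = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let log_sum_exp = max + logits.iter().map(|&z| (z - max).exp()).sum::<f64>().ln();
    // loss = log(sum_j exp(z_j)) - z_target  ==  -log softmax(z)_target
    log_sum_exp - logits[target]
}

fn main() {
    let logits = [2.0, 1.0, -1.0];
    println!("loss = {:.4}", cross_entropy(&logits, 0));
}
```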
### Learning Rate Schedulers
- StepLR: Decay by factor every N epochs
- OneCycleLR: One cycle policy (warmup + annealing)
- CosineAnnealingLR: Smooth cosine oscillation with warm restarts
- ReduceLROnPlateau: Reduce when validation loss plateaus
- PolynomialLR: Polynomial decay with configurable power
- CyclicalLR: Triangular oscillation with multiple modes
- WarmupScheduler: Gradual increase wrapper for any scheduler
- LinearLR: Linear interpolation between learning rates
## Testing
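Run the unit tests with the standard Cargo test runner:

```bash
cargo test
```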
## Performance Examples

The library includes comprehensive examples that demonstrate its capabilities; each is launched with `cargo run --example <name>`.

### Training with Different Schedulers

Run the learning rate scheduling examples to compare scheduler behaviors.

### Architecture Comparison

Compare LSTM vs GRU performance.

### Real-world Applications

Test the library with practical examples.
The examples output training metrics, loss values, and predictions that you can analyze or plot with external tools.
## Version History
- v0.4.0: Advanced learning rate scheduling with 12 different schedulers, warmup support, cyclical learning rates, polynomial decay, and ASCII visualization
- v0.3.0: Bidirectional LSTM networks with flexible combine modes
- v0.2.0: Complete training system with BPTT and comprehensive dropout
- v0.1.0: Initial LSTM implementation with forward pass
## Contributing
Contributions are welcome! Please submit issues, feature requests, or pull requests.
## License
MIT License - see the LICENSE file for details.