Rust-LSTM

A comprehensive LSTM (Long Short-Term Memory) neural network library implemented in Rust with complete training capabilities, multiple optimizers, and advanced regularization.

Network Architecture Overview

graph TD
    A["Input Sequence<br/>(x₁, x₂, ..., xₜ)"] --> B["LSTM Layer 1"]
    B --> C["LSTM Layer 2"]
    C --> D["Output Layer"]
    D --> E["Predictions<br/>(y₁, y₂, ..., yₜ)"]
    
    F["Hidden State h₀"] --> B
    G["Cell State c₀"] --> B
    
    B --> H["Hidden State h₁"]
    B --> I["Cell State c₁"]
    
    H --> C
    I --> C
    
    style A fill:#e1f5fe
    style E fill:#e8f5e8
    style B fill:#fff3e0
    style C fill:#fff3e0

Features

  • LSTM, BiLSTM & GRU Networks with multi-layer support
  • Complete Training System with backpropagation through time (BPTT)
  • Multiple Optimizers: SGD, Adam, RMSprop with comprehensive learning rate scheduling
  • Advanced Learning Rate Scheduling: 12 different schedulers including OneCycle, Warmup, Cyclical, and Polynomial
  • Loss Functions: MSE, MAE, Cross-entropy with softmax
  • Advanced Dropout: Input, recurrent, output dropout, variational dropout, and zoneout
  • Schedule Visualization: ASCII visualization of learning rate schedules
  • Model Persistence: Save/load models in JSON or binary format (a generic serialization sketch follows this list)
  • Peephole LSTM variant (gates read the cell state directly)
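
Model persistence generally means serializing the learned weights and hyperparameters. The sketch below is only a generic serde/serde_json round trip with a hypothetical SavedWeights struct (it assumes serde with the derive feature and serde_json as dependencies); the crate's own save/load methods may differ, so check the API documentation for the actual interface.

use serde::{Deserialize, Serialize};

// Hypothetical container used only for this sketch; not the crate's model type.
#[derive(Serialize, Deserialize)]
struct SavedWeights {
    input_size: usize,
    hidden_size: usize,
    weights: Vec<Vec<f64>>,
}

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let model = SavedWeights {
        input_size: 3,
        hidden_size: 10,
        weights: vec![vec![0.1; 13]; 40],
    };
    // Save to JSON...
    std::fs::write("model.json", serde_json::to_string_pretty(&model)?)?;
    // ...and load it back.
    let restored: SavedWeights = serde_json::from_str(&std::fs::read_to_string("model.json")?)?;
    assert_eq!(restored.hidden_size, 10);
    Ok(())
}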

Quick Start

Add to your Cargo.toml:

[dependencies]
rust-lstm = "0.4.0"

Basic Usage

use ndarray::Array2;
use rust_lstm::models::lstm_network::LSTMNetwork;

fn main() {
    // Create LSTM network
    let mut network = LSTMNetwork::new(3, 10, 2); // input_size, hidden_size, num_layers
    
    // Create input data
    let input = Array2::from_shape_vec((3, 1), vec![0.5, 0.1, -0.3]).unwrap();
    let hx = Array2::zeros((10, 1));
    let cx = Array2::zeros((10, 1));
    
    // Forward pass
    let (output, _) = network.forward(&input, &hx, &cx);
    println!("Output: {:?}", output);
}

Training Example

use rust_lstm::{LSTMNetwork, create_basic_trainer, TrainingConfig};
use rust_lstm::optimizers::Adam;
use rust_lstm::loss::MSELoss;

fn main() {
    // Create network with dropout
    let network = LSTMNetwork::new(1, 10, 2)
        .with_input_dropout(0.2, true)
        .with_recurrent_dropout(0.3, true);
    
    // Setup trainer
    let mut trainer = create_basic_trainer(
        network,
        MSELoss,
        Adam::new(0.001)
    ).with_config(TrainingConfig {
        epochs: 100,
        clip_gradient: Some(1.0),
        ..Default::default()
    });
    
    // Train (train_data and validation_data are Vec<(input, target)> pairs prepared elsewhere)
    trainer.train(&train_data, Some(&validation_data));
}
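
The clip_gradient option above caps gradient magnitudes during BPTT so long sequences do not blow up the updates. As a point of reference, here is a minimal, library-independent sketch of global-norm clipping; whether the crate clips by norm or by value is a detail of its training module, so treat this purely as an illustration of the idea.

// Global-norm gradient clipping: if ||g|| exceeds max_norm, rescale g to have norm max_norm.
fn clip_gradient_norm(grads: &mut [f64], max_norm: f64) {
    let norm = grads.iter().map(|g| g * g).sum::<f64>().sqrt();
    if norm > max_norm {
        let scale = max_norm / norm;
        for g in grads.iter_mut() {
            *g *= scale;
        }
    }
}

fn main() {
    let mut grads = vec![3.0, 4.0]; // norm = 5.0
    clip_gradient_norm(&mut grads, 1.0);
    println!("{:?}", grads); // [0.6, 0.8] -> norm = 1.0
}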

Bidirectional LSTM

use rust_lstm::layers::bilstm_network::{BiLSTMNetwork, CombineMode};

// BiLSTM with concatenated outputs (output_size = 2 * hidden_size)
let mut bilstm = BiLSTMNetwork::new_concat(input_size, hidden_size, num_layers);

// Process sequence with both past and future context
let outputs = bilstm.forward_sequence(&sequence);
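
The CombineMode import controls how the forward and backward hidden states are merged at each time step. Below is a shape-level sketch of the three modes shown in the architecture diagram, with a hypothetical Mode enum standing in for the crate's type (illustration only, not the crate's code):

// Hypothetical stand-in for the crate's CombineMode, used only for this sketch.
enum Mode {
    Concat,  // output length = 2 * hidden_size
    Sum,     // output length = hidden_size
    Average, // output length = hidden_size
}

fn combine(forward_h: &[f64], backward_h: &[f64], mode: Mode) -> Vec<f64> {
    match mode {
        Mode::Concat => forward_h.iter().chain(backward_h).cloned().collect(),
        Mode::Sum => forward_h.iter().zip(backward_h).map(|(f, b)| f + b).collect(),
        Mode::Average => forward_h.iter().zip(backward_h).map(|(f, b)| (f + b) / 2.0).collect(),
    }
}

fn main() {
    let (f, b) = (vec![1.0, 2.0], vec![3.0, 4.0]);
    println!("{:?}", combine(&f, &b, Mode::Concat));  // [1.0, 2.0, 3.0, 4.0]
    println!("{:?}", combine(&f, &b, Mode::Sum));     // [4.0, 6.0]
    println!("{:?}", combine(&f, &b, Mode::Average)); // [2.0, 3.0]
}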

BiLSTM Architecture

graph TD
    A["Input Sequence<br/>(x₁, x₂, x₃, x₄)"] --> B["Forward LSTM"]
    A --> C["Backward LSTM"]
    
    B --> D["Forward Hidden States<br/>(h₁→, h₂→, h₃→, h₄→)"]
    C --> E["Backward Hidden States<br/>(h₁←, h₂←, h₃←, h₄←)"]
    
    D --> F["Combine Layer<br/>(Concat/Sum/Average)"]
    E --> F
    
    F --> G["BiLSTM Output<br/>(combined representations)"]
    
    style A fill:#e1f5fe
    style B fill:#fff3e0
    style C fill:#fff3e0
    style F fill:#f3e5f5
    style G fill:#e8f5e8

GRU Networks

use rust_lstm::models::gru_network::GRUNetwork;

// Create GRU network (alternative to LSTM)
let mut gru = GRUNetwork::new(input_size, hidden_size, num_layers)
    .with_input_dropout(0.2, true)
    .with_recurrent_dropout(0.3, true);

// Forward pass
let (output, _) = gru.forward(&input, &hidden_state);

LSTM vs GRU Cell Comparison

graph LR
    subgraph "LSTM Cell"
        A1["Input xₜ"] --> B1["Forget Gate<br/>fₜ = σ(Wf·[hₜ₋₁,xₜ] + bf)"]
        A1 --> C1["Input Gate<br/>iₜ = σ(Wi·[hₜ₋₁,xₜ] + bi)"]
        A1 --> D1["Candidate Values<br/>C̃ₜ = tanh(WC·[hₜ₋₁,xₜ] + bC)"]
        A1 --> E1["Output Gate<br/>oₜ = σ(Wo·[hₜ₋₁,xₜ] + bo)"]
        
        B1 --> F1["Cell State<br/>Cₜ = fₜ * Cₜ₋₁ + iₜ * C̃ₜ"]
        C1 --> F1
        D1 --> F1
        
        F1 --> G1["Hidden State<br/>hₜ = oₜ * tanh(Cₜ)"]
        E1 --> G1
    end
    
    subgraph "GRU Cell"
        A2["Input xₜ"] --> B2["Reset Gate<br/>rₜ = σ(Wr·[hₜ₋₁,xₜ])"]
        A2 --> C2["Update Gate<br/>zₜ = σ(Wz·[hₜ₋₁,xₜ])"]
        A2 --> D2["Candidate State<br/>h̃ₜ = tanh(W·[rₜ*hₜ₋₁,xₜ])"]
        
        B2 --> D2
        C2 --> E2["Hidden State<br/>hₜ = (1-zₜ)*hₜ₋₁ + zₜ*h̃ₜ"]
        D2 --> E2
    end
    
    style B1 fill:#ffcdd2
    style C1 fill:#c8e6c9
    style D1 fill:#fff3e0
    style E1 fill:#e1f5fe
    style B2 fill:#ffcdd2
    style C2 fill:#c8e6c9
    style D2 fill:#fff3e0
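
As a reading aid for the GRU equations in the diagram, here is a single GRU step written directly from those formulas with ndarray (a sketch of the math only, not the crate's GRUCell implementation; biases are omitted, as in the diagram):

use ndarray::{concatenate, Array1, Array2, Axis};

fn sigmoid(v: Array1<f64>) -> Array1<f64> {
    v.mapv(|x| 1.0 / (1.0 + (-x).exp()))
}

fn gru_step(
    x: &Array1<f64>,
    h_prev: &Array1<f64>,
    w_r: &Array2<f64>,
    w_z: &Array2<f64>,
    w_h: &Array2<f64>,
) -> Array1<f64> {
    let hx = concatenate(Axis(0), &[h_prev.view(), x.view()]).unwrap(); // [h_{t-1}, x_t]
    let r = sigmoid(w_r.dot(&hx)); // reset gate  r_t
    let z = sigmoid(w_z.dot(&hx)); // update gate z_t
    let rh = &r * h_prev;          // r_t * h_{t-1}
    let rhx = concatenate(Axis(0), &[rh.view(), x.view()]).unwrap();
    let h_tilde = w_h.dot(&rhx).mapv(f64::tanh); // candidate state
    (1.0 - &z) * h_prev + &z * &h_tilde          // h_t = (1 - z_t) * h_{t-1} + z_t * h~_t
}

fn main() {
    let (input_size, hidden_size) = (3, 4);
    let x = Array1::from_elem(input_size, 0.5);
    let h = Array1::zeros(hidden_size);
    let w = Array2::from_elem((hidden_size, hidden_size + input_size), 0.1);
    println!("{:?}", gru_step(&x, &h, &w, &w, &w));
}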

Advanced Learning Rate Scheduling

The library includes 12 different learning rate schedulers with visualization capabilities:

use rust_lstm::{
    create_step_lr_trainer, create_one_cycle_trainer, create_cosine_annealing_trainer,
    ScheduledOptimizer, PolynomialLR, CyclicalLR, WarmupScheduler,
    LRScheduleVisualizer, Adam
};

// Step decay: reduce LR by 50% every 10 epochs
let mut trainer = create_step_lr_trainer(network, 0.01, 10, 0.5);

// OneCycle policy for modern deep learning
let mut trainer = create_one_cycle_trainer(network, 0.1, 100);

// Cosine annealing with warm restarts
let mut trainer = create_cosine_annealing_trainer(network, 0.01, 20, 1e-6);

// Advanced combinations - Warmup + Cyclical scheduling
let base_scheduler = CyclicalLR::new(0.001, 0.01, 10);
let warmup_scheduler = WarmupScheduler::new(5, base_scheduler, 0.0001);
let optimizer = ScheduledOptimizer::new(Adam::new(0.01), warmup_scheduler, 0.01);

// Polynomial decay with visualization
let poly_scheduler = PolynomialLR::new(100, 2.0, 0.001);
LRScheduleVisualizer::print_schedule(poly_scheduler, 0.01, 100, 60, 10);
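
For intuition about what the step-decay and cosine-annealing trainers above do to the learning rate over time, here is a small library-independent sketch of the two closed-form rules (not the crate's scheduler code):

use std::f64::consts::PI;

// Step decay: lr(epoch) = base_lr * gamma^floor(epoch / step_size)
fn step_lr(base_lr: f64, gamma: f64, step_size: usize, epoch: usize) -> f64 {
    base_lr * gamma.powi((epoch / step_size) as i32)
}

// Cosine annealing with a restart every t_max epochs:
// lr(epoch) = eta_min + 0.5 * (base_lr - eta_min) * (1 + cos(pi * (epoch mod t_max) / t_max))
fn cosine_annealing_lr(base_lr: f64, t_max: usize, eta_min: f64, epoch: usize) -> f64 {
    let progress = (epoch % t_max) as f64 / t_max as f64;
    eta_min + 0.5 * (base_lr - eta_min) * (1.0 + (PI * progress).cos())
}

fn main() {
    // Mirrors the step-decay trainer above: start at 0.01 and halve every 10 epochs.
    for epoch in [0, 10, 20, 30] {
        println!("step   epoch {:>2}: lr = {:.5}", epoch, step_lr(0.01, 0.5, 10, epoch));
    }
    // Mirrors the cosine-annealing trainer above: 0.01 down toward 1e-6 over a 20-epoch cycle.
    for epoch in [0, 5, 10, 19] {
        println!("cosine epoch {:>2}: lr = {:.6}", epoch, cosine_annealing_lr(0.01, 20, 1e-6, epoch));
    }
}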

Available Schedulers:

  • ConstantLR: No scheduling (baseline)
  • StepLR: Step decay at regular intervals
  • MultiStepLR: Multi-step decay at specific milestones
  • ExponentialLR: Exponential decay each epoch
  • CosineAnnealingLR: Smooth cosine oscillation
  • CosineAnnealingWarmRestarts: Cosine with periodic restarts
  • OneCycleLR: One cycle policy for super-convergence
  • ReduceLROnPlateau: Adaptive reduction on validation plateaus
  • LinearLR: Linear interpolation between rates
  • PolynomialLR (new in v0.4.0): Polynomial decay with configurable power
  • CyclicalLR (new in v0.4.0): Triangular, triangular2, and exponential range modes
  • WarmupScheduler (new in v0.4.0): Gradual warmup wrapper for any base scheduler

Architecture

  • layers: LSTM and GRU cells (standard, peephole, bidirectional) with dropout
  • models: High-level network architectures (LSTM, BiLSTM, GRU)
  • training: Training utilities with automatic train/eval mode switching
  • optimizers: SGD, Adam, RMSprop with scheduling
  • loss: MSE, MAE, Cross-entropy loss functions
  • schedulers: Learning rate scheduling algorithms

Examples

Run examples to see the library in action:

# Basic usage and training
cargo run --example basic_usage
cargo run --example training_example
cargo run --example multi_layer_lstm
cargo run --example time_series_prediction

# Advanced architectures
cargo run --example gru_example              # GRU vs LSTM comparison
cargo run --example bilstm_example           # Bidirectional LSTM
cargo run --example dropout_example          # Comprehensive dropout demo

# Learning and scheduling
cargo run --example learning_rate_scheduling    # Basic schedulers
cargo run --example advanced_lr_scheduling      # Advanced schedulers with visualization

# Real-world applications
cargo run --example stock_prediction
cargo run --example weather_prediction
cargo run --example text_classification_bilstm
cargo run --example text_generation_advanced
cargo run --example real_data_example

# Analysis and debugging
cargo run --example model_inspection

Advanced Features

Dropout Types

  • Input Dropout: Applied to inputs before computing gates
  • Recurrent Dropout: Applied to hidden states, with variational support (same mask reused across time steps; sketched after this list)
  • Output Dropout: Applied to layer outputs
  • Zoneout: RNN-specific regularization preserving previous states
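
A library-independent sketch of the two recurrent-specific ideas above: variational dropout samples one mask per sequence and reuses it at every time step, while zoneout randomly carries individual hidden units forward from the previous step (illustration only, not the crate's implementation):

// A deterministic stand-in replaces a real RNG so the sketch has no dependencies.
fn variational_mask(hidden_size: usize, keep_prob: f64, rand: impl Fn() -> f64) -> Vec<f64> {
    // Sample ONE mask per sequence and reuse it at every time step,
    // scaling kept units by 1 / keep_prob (inverted dropout).
    (0..hidden_size)
        .map(|_| if rand() < keep_prob { 1.0 / keep_prob } else { 0.0 })
        .collect()
}

fn zoneout(prev_h: &[f64], new_h: &[f64], zoneout_prob: f64, rand: impl Fn() -> f64) -> Vec<f64> {
    // Each hidden unit either keeps its previous value or takes the freshly computed one.
    prev_h
        .iter()
        .zip(new_h)
        .map(|(&p, &n)| if rand() < zoneout_prob { p } else { n })
        .collect()
}

fn main() {
    let fake_rand = || 0.3; // stand-in for a uniform [0, 1) sample
    let mask = variational_mask(4, 0.8, fake_rand);
    let h = zoneout(&[1.0, 2.0, 3.0, 4.0], &[0.5, 0.5, 0.5, 0.5], 0.1, fake_rand);
    println!("mask = {:?}, h = {:?}", mask, h);
}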

Optimizers

  • SGD: Stochastic gradient descent with momentum
  • Adam: Adaptive moment estimation with bias correction (sketched after this list)
  • RMSprop: Root mean square propagation
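
The bias correction mentioned for Adam rescales the running moment estimates, which start at zero and would otherwise be biased low early in training. A single-parameter sketch of the standard update (illustration only, not the crate's optimizer code):

struct AdamState {
    m: f64, // first moment (running mean of gradients)
    v: f64, // second moment (running mean of squared gradients)
    t: u32, // step counter
}

fn adam_step(param: &mut f64, grad: f64, s: &mut AdamState, lr: f64) {
    let (beta1, beta2, eps) = (0.9, 0.999, 1e-8);
    s.t += 1;
    s.m = beta1 * s.m + (1.0 - beta1) * grad;
    s.v = beta2 * s.v + (1.0 - beta2) * grad * grad;
    // Bias correction: m and v start at zero, so rescale them early in training.
    let m_hat = s.m / (1.0 - beta1.powi(s.t as i32));
    let v_hat = s.v / (1.0 - beta2.powi(s.t as i32));
    *param -= lr * m_hat / (v_hat.sqrt() + eps);
}

fn main() {
    let mut w = 0.5;
    let mut state = AdamState { m: 0.0, v: 0.0, t: 0 };
    adam_step(&mut w, 0.1, &mut state, 0.001);
    println!("w = {:.6}", w); // the first step moves w by roughly lr in the gradient direction
}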

Loss Functions

  • MSELoss: Mean squared error for regression
  • MAELoss: Mean absolute error for robust regression
  • CrossEntropyLoss: Numerically stable softmax cross-entropy for classification
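
"Numerically stable" refers to the usual log-sum-exp trick: subtract the maximum logit before exponentiating so large logits cannot overflow. A standalone sketch (not the crate's CrossEntropyLoss implementation):

fn softmax_cross_entropy(logits: &[f64], target_class: usize) -> f64 {
    let max = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    // log(sum(exp(z_i))) computed as max + log(sum(exp(z_i - max))) so exp() cannot overflow.
    let log_sum_exp = max + logits.iter().map(|z| (z - max).exp()).sum::<f64>().ln();
    log_sum_exp - logits[target_class] // equals -log(softmax(logits)[target_class])
}

fn main() {
    // Logits this large would overflow a naive exp(); the stabilized form is fine.
    let loss = softmax_cross_entropy(&[1000.0, 1001.0, 999.0], 1);
    println!("loss = {:.4}", loss); // ~0.4076
}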

Learning Rate Schedulers

  • StepLR: Decay by factor every N epochs
  • OneCycleLR: One cycle policy (warmup + annealing)
  • CosineAnnealingLR / CosineAnnealingWarmRestarts: Smooth cosine decay, optionally with periodic restarts
  • ReduceLROnPlateau: Reduce when validation loss plateaus
  • PolynomialLR: Polynomial decay with configurable power
  • CyclicalLR: Triangular oscillation with multiple modes
  • WarmupScheduler: Gradual increase wrapper for any scheduler
  • LinearLR: Linear interpolation between learning rates

Testing

cargo test

Performance Examples

The library includes comprehensive examples that demonstrate its capabilities. Here are some suggested visualizations you can generate by running the examples:

Training Curves

Run the learning rate scheduling example to see different scheduler behaviors:

cargo run --example learning_rate_scheduling

Suggested visualization: Learning rate curves and loss curves for different schedulers

Architecture Comparison

Compare LSTM vs GRU performance:

cargo run --example gru_example

Suggested visualization: Training time and accuracy comparison charts

Real-world Applications

Generate prediction plots from the examples:

cargo run --example stock_prediction      # Stock price predictions
cargo run --example weather_prediction    # Weather forecasting  
cargo run --example text_classification_bilstm  # Classification accuracy

Suggested visualization: Prediction vs actual plots, confusion matrices, and accuracy metrics

Version History

  • v0.4.0: Advanced learning rate scheduling with 12 different schedulers, warmup support, cyclical learning rates, polynomial decay, and ASCII visualization
  • v0.3.0: Bidirectional LSTM networks with flexible combine modes
  • v0.2.0: Complete training system with BPTT and comprehensive dropout
  • v0.1.0: Initial LSTM implementation with forward pass

Contributing

Contributions are welcome! Please submit issues, feature requests, or pull requests.

License

MIT License - see the LICENSE file for details.