Rust-LSTM
A comprehensive LSTM (Long Short-Term Memory) neural network library implemented in Rust. The library provides everything you need to create, train, and use LSTM networks for a variety of sequence modeling tasks.
Features
- LSTM cell implementation with forward and backward propagation
- Peephole LSTM variant, in which the gates also observe the cell state
- Multi-layer LSTM networks with configurable architecture
- Complete training system with backpropagation through time (BPTT)
- Multiple optimizers: SGD, Adam, RMSprop
- Loss functions: MSE, MAE, Cross-entropy with softmax
- Training utilities: gradient clipping, validation, metrics tracking
- Random initialization of weights and biases
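To see what the cell computes at each time step, here is the standard LSTM recurrence written as a self-contained sketch in plain Rust. It collapses the weight matrices to scalars for readability; it is illustrative only, not the library's API:

```rust
// Logistic sigmoid, used by the three gates.
fn sigmoid(x: f64) -> f64 {
    1.0 / (1.0 + (-x).exp())
}

// One LSTM step for scalar input, hidden state, and cell state.
// Real cells use weight matrices; the per-gate weights here are
// arbitrary example values.
fn lstm_step(x: f64, h_prev: f64, c_prev: f64) -> (f64, f64) {
    let (wf, uf, bf) = (0.5, 0.1, 0.0); // forget gate weights
    let (wi, ui, bi) = (0.6, 0.2, 0.0); // input gate weights
    let (wo, uo, bo) = (0.7, 0.3, 0.0); // output gate weights
    let (wg, ug, bg) = (0.4, 0.2, 0.0); // candidate cell weights

    let f = sigmoid(wf * x + uf * h_prev + bf); // how much old state to keep
    let i = sigmoid(wi * x + ui * h_prev + bi); // how much new state to admit
    let o = sigmoid(wo * x + uo * h_prev + bo); // how much state to expose
    let g = (wg * x + ug * h_prev + bg).tanh(); // candidate cell state

    let c = f * c_prev + i * g; // new cell state
    let h = o * c.tanh();       // new hidden state
    (h, c)
}

fn main() {
    // Run the recurrence over a short input sequence.
    let (mut h, mut c) = (0.0, 0.0);
    for &x in &[1.0, 0.5, -0.25] {
        let (nh, nc) = lstm_step(x, h, c);
        h = nh;
        c = nc;
        println!("h = {h:.4}, c = {c:.4}");
    }
}
```

The backward pass (BPTT) differentiates exactly these equations through time, which is why the gates are built from smooth functions like `sigmoid` and `tanh`.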
Getting Started
Prerequisites
Ensure you have Rust installed on your machine. If Rust is not already installed, you can install it by following the instructions on the official Rust website: https://www.rust-lang.org/tools/install.
Installing
To use Rust-LSTM in your project, add the following to your Cargo.toml:
```toml
[dependencies]
rust-lstm = "0.1.0"
```
Then, run the following command to build your project and download the Rust-LSTM crate:

```sh
cargo build
```
Usage
Basic Forward Pass
Here's a simple example demonstrating basic LSTM usage:
```rust
// Module paths reconstructed; see the crate docs for the exact ones.
use ndarray::Array2;
use rust_lstm::LSTMNetwork;
```
Training an LSTM Network
Here's how to train an LSTM for time series prediction:
```rust
// Module paths reconstructed; see the crate docs for the exact ones.
use ndarray::Array2;
use rust_lstm::LSTMNetwork;
use rust_lstm::optimizers::Adam;
use rust_lstm::losses::MSELoss;
```
Advanced Features
Using Different Optimizers
```rust
use rust_lstm::optimizers::{SGD, Adam, RMSprop};

// SGD optimizer (argument values here are illustrative)
let sgd = SGD::new(0.01);

// Adam optimizer with custom parameters
// (indicative arguments: learning rate, beta1, beta2, epsilon)
let adam = Adam::with_params(0.001, 0.9, 0.999, 1e-8);

// RMSprop optimizer
let rmsprop = RMSprop::new(0.001);
```
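For intuition about what these optimizers do, here is a single Adam update written as a standalone sketch of the standard algorithm in plain Rust (not the crate's internal code):

```rust
// One Adam step for a single scalar parameter.
// `m` and `v` are the running first/second moment estimates,
// `t` is the 1-based step count.
fn adam_step(
    param: f64, grad: f64, m: f64, v: f64, t: u32,
    lr: f64, beta1: f64, beta2: f64, eps: f64,
) -> (f64, f64, f64) {
    let m = beta1 * m + (1.0 - beta1) * grad;        // biased first moment
    let v = beta2 * v + (1.0 - beta2) * grad * grad; // biased second moment
    let m_hat = m / (1.0 - beta1.powi(t as i32));    // bias correction
    let v_hat = v / (1.0 - beta2.powi(t as i32));
    let param = param - lr * m_hat / (v_hat.sqrt() + eps);
    (param, m, v)
}

fn main() {
    // One step with typical hyperparameters: lr 0.001, betas 0.9/0.999.
    let (p, m, v) = adam_step(1.0, 0.5, 0.0, 0.0, 1, 0.001, 0.9, 0.999, 1e-8);
    println!("param = {p:.6}, m = {m:.6}, v = {v:.6}");
}
```

SGD applies `param -= lr * grad` directly, while RMSprop keeps only the second-moment average; Adam combines both moments, which is why it is a common default.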
Different Loss Functions
```rust
use rust_lstm::losses::{MSELoss, MAELoss, CrossEntropyLoss};

// Mean Squared Error for regression
let mse_loss = MSELoss;

// Mean Absolute Error for robust regression
let mae_loss = MAELoss;

// Cross-Entropy for classification
let ce_loss = CrossEntropyLoss;
```
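Under the hood these losses reduce to short formulas. Here they are as a self-contained sketch in plain Rust (illustrative, not the crate's implementation):

```rust
// Mean Squared Error: average of squared differences.
fn mse(pred: &[f64], target: &[f64]) -> f64 {
    pred.iter().zip(target)
        .map(|(p, t)| (p - t).powi(2))
        .sum::<f64>() / pred.len() as f64
}

// Mean Absolute Error: average of absolute differences,
// less sensitive to outliers than MSE.
fn mae(pred: &[f64], target: &[f64]) -> f64 {
    pred.iter().zip(target)
        .map(|(p, t)| (p - t).abs())
        .sum::<f64>() / pred.len() as f64
}

// Cross-entropy with softmax: softmax the logits (shifted by the max
// for numerical stability), then take the negative log-probability
// of the true class.
fn cross_entropy(logits: &[f64], true_class: usize) -> f64 {
    let max = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = logits.iter().map(|z| (z - max).exp()).collect();
    let sum: f64 = exps.iter().sum();
    -(exps[true_class] / sum).ln()
}

fn main() {
    println!("mse = {}", mse(&[1.0, 2.0], &[1.0, 4.0]));
    println!("mae = {}", mae(&[1.0, 2.0], &[1.0, 4.0]));
    println!("ce  = {}", cross_entropy(&[1.0, 1.0], 0));
}
```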
Peephole LSTM
```rust
use rust_lstm::PeepholeLSTMCell;

// Constructor and forward arguments reconstructed; see the crate docs.
let cell = PeepholeLSTMCell::new(input_size, hidden_size);
let (h, c) = cell.forward(&x, &h_prev, &c_prev);
```
To run this example, save it as main.rs and run:

```sh
cargo run
```
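What distinguishes the peephole variant is that its gates also "peek" at the cell state. The difference can be shown on a single gate in plain Rust (a sketch of the idea, not the crate's code):

```rust
// Logistic sigmoid, used by all gates.
fn sigmoid(x: f64) -> f64 {
    1.0 / (1.0 + (-x).exp())
}

// Standard forget gate: sees the input and the previous hidden state.
fn forget_standard(x: f64, h_prev: f64, w: f64, u: f64, b: f64) -> f64 {
    sigmoid(w * x + u * h_prev + b)
}

// Peephole forget gate: additionally sees the previous cell state
// `c_prev` through an extra "peephole" weight `p`.
fn forget_peephole(
    x: f64, h_prev: f64, c_prev: f64,
    w: f64, u: f64, p: f64, b: f64,
) -> f64 {
    sigmoid(w * x + u * h_prev + p * c_prev + b)
}

fn main() {
    // With a nonzero cell state, the peephole connection shifts the gate.
    let std_f = forget_standard(0.5, 0.1, 0.4, 0.2, 0.0);
    let pee_f = forget_peephole(0.5, 0.1, 0.8, 0.4, 0.2, 0.3, 0.0);
    println!("standard = {std_f:.4}, peephole = {pee_f:.4}");
}
```

With the peephole weight set to zero (or a zero cell state) the two gates coincide, which is why the variant is a strict generalization of the standard cell.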
Examples
The library includes several examples demonstrating different use cases:
- `basic_usage.rs` - Simple forward pass example
- `training_example.rs` - Complete training workflow with multiple optimizers
- `time_series_prediction.rs` - Time series forecasting
- `text_generation.rs` - Character-level text generation
- `multi_layer_lstm.rs` - Multi-layer network usage
- `peephole.rs` - Peephole LSTM variant
Run examples with:

```sh
cargo run --example basic_usage
```
Architecture
The library is organized into several modules:
- `layers`: LSTM cell implementations (standard and peephole variants)
- `models`: High-level network architectures
- `loss`: Loss functions for training
- `optimizers`: Optimization algorithms
- `training`: Training utilities and trainer struct
- `utils`: Common utility functions
Training Features
- Backpropagation Through Time (BPTT): Complete gradient computation for sequence modeling
- Gradient Clipping: Prevents exploding gradients during training
- Multiple Optimizers: SGD, Adam, RMSprop with configurable parameters
- Validation Support: Track validation metrics during training
- Metrics Tracking: Loss curves and training progress monitoring
- Flexible Training Loop: Configurable epochs, learning rates, and logging
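The gradient clipping listed above boils down to rescaling the gradient by its global norm. A self-contained sketch in plain Rust (illustrative, not the crate's implementation):

```rust
// Gradient clipping by global norm: if the gradient vector's L2 norm
// exceeds `max_norm`, rescale every component so the norm equals
// `max_norm`; otherwise leave the gradients untouched.
fn clip_by_norm(grads: &mut [f64], max_norm: f64) {
    let norm = grads.iter().map(|g| g * g).sum::<f64>().sqrt();
    if norm > max_norm {
        let scale = max_norm / norm;
        for g in grads.iter_mut() {
            *g *= scale;
        }
    }
}

fn main() {
    // Norm of [3, 4] is 5, so clipping to 1.0 rescales by 1/5.
    let mut grads = vec![3.0, 4.0];
    clip_by_norm(&mut grads, 1.0);
    println!("{grads:?}");
}
```

Rescaling the whole vector (rather than clamping each component) preserves the gradient's direction, which is the usual choice for taming exploding gradients in BPTT.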
Running the Tests
To run the tests included with Rust-LSTM, execute:

```sh
cargo test
```
This will run all the unit and integration tests defined in the library.
Contributing
Contributions to Rust-LSTM are welcome! Here are a few ways you can help:
- Report bugs and issues
- Suggest new features or improvements
- Open a pull request with improvements to code or documentation
Please read CONTRIBUTING.md for details on our code of conduct and the process for submitting pull requests.
License
This project is licensed under the MIT License - see the LICENSE file for details.