rust_lstm/
lib.rs

//! # Rust LSTM Library
//!
//! A complete LSTM implementation with training capabilities, multiple optimizers,
//! dropout regularization, and support for various architectures including peephole
//! connections and bidirectional processing.
//!
//! ## Core Components
//!
//! - **LSTM Cells**: Standard and peephole LSTM implementations with full backpropagation
//! - **Bidirectional LSTM**: Process sequences in both directions with flexible output combination (see the sketch below)
//! - **Networks**: Multi-layer LSTM networks for sequence modeling
//! - **Training**: Complete training system with BPTT, gradient clipping, and validation (see Scheduled Training below)
//! - **Optimizers**: SGD, Adam, and RMSprop optimizers with adaptive learning rates
//! - **Loss Functions**: MSE, MAE, and Cross-Entropy with numerically stable implementations
//! - **Dropout**: Input, recurrent, output dropout and zoneout regularization
//!
//! ## Quick Start
//!
//! ```rust
//! use rust_lstm::models::lstm_network::LSTMNetwork;
//! use rust_lstm::training::create_basic_trainer;
//!
//! // Create a 2-layer LSTM with 10 input features and 20 hidden units
//! let mut network = LSTMNetwork::new(10, 20, 2)
//!     .with_input_dropout(0.2, true)     // Variational input dropout
//!     .with_recurrent_dropout(0.3, true) // Variational recurrent dropout
//!     .with_output_dropout(0.1);         // Standard output dropout
//!
//! let mut trainer = create_basic_trainer(network, 0.001);
//!
//! // Train on your data
//! // trainer.train(&train_data, Some(&validation_data));
//! ```
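//!
//! ## Bidirectional LSTM
//!
//! `BiLSTMNetwork` and `CombineMode` are re-exported at the crate root for processing
//! sequences in both directions. The sketch below is illustrative only: it assumes a
//! constructor that mirrors `LSTMNetwork::new` with a combine mode appended, and that
//! `Concat` is one of the `CombineMode` variants; check `layers::bilstm_network` for
//! the actual signature.
//!
//! ```rust,ignore
//! use rust_lstm::{BiLSTMNetwork, CombineMode};
//!
//! // Hypothetical call: 10 input features, 20 hidden units, 2 layers,
//! // concatenating forward and backward hidden states at each step.
//! let bilstm = BiLSTMNetwork::new(10, 20, 2, CombineMode::Concat);
//! ```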
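//!
//! ## Scheduled Training
//!
//! Beyond `create_basic_trainer`, the `training` module exports factories for
//! schedule-driven trainers: `create_step_lr_trainer`, `create_one_cycle_trainer`,
//! and `create_cosine_annealing_trainer`. A minimal sketch, assuming the step-LR
//! factory takes the network, a base learning rate, a step size in epochs, and a
//! decay factor; consult the `training` module for the real parameter list:
//!
//! ```rust,ignore
//! use rust_lstm::models::lstm_network::LSTMNetwork;
//! use rust_lstm::training::create_step_lr_trainer;
//!
//! let network = LSTMNetwork::new(10, 20, 2);
//! // Assumed parameters: halve the learning rate every 10 epochs.
//! let mut trainer = create_step_lr_trainer(network, 0.01, 10, 0.5);
//! // trainer.train(&train_data, Some(&validation_data));
//! ```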

// Module declarations.
pub mod utils;
pub mod layers;
pub mod models;
pub mod loss;
pub mod optimizers;
pub mod schedulers;
pub mod training;
pub mod persistence;

// Re-export commonly used items
pub use models::lstm_network::{LSTMNetwork, LSTMNetworkCache, LSTMNetworkBatchCache, LayerDropoutConfig};
pub use models::gru_network::{GRUNetwork, LayerDropoutConfig as GRULayerDropoutConfig, GRUNetworkCache};
pub use layers::lstm_cell::{LSTMCell, LSTMCellCache, LSTMCellBatchCache, LSTMCellGradients};
pub use layers::peephole_lstm_cell::PeepholeLSTMCell;
pub use layers::gru_cell::{GRUCell, GRUCellGradients, GRUCellCache};
pub use layers::bilstm_network::{BiLSTMNetwork, CombineMode, BiLSTMNetworkCache};
pub use layers::dropout::{Dropout, Zoneout};
pub use layers::linear::{LinearLayer, LinearGradients};
pub use training::{
    LSTMTrainer, ScheduledLSTMTrainer, LSTMBatchTrainer, TrainingConfig, TrainingMetrics,
    EarlyStoppingConfig, EarlyStoppingMetric, EarlyStopper,
    create_basic_trainer, create_step_lr_trainer, create_one_cycle_trainer, create_cosine_annealing_trainer,
    create_basic_batch_trainer, create_adam_batch_trainer,
};
pub use optimizers::{SGD, Adam, RMSprop, ScheduledOptimizer};
pub use schedulers::{
    LearningRateScheduler, ConstantLR, StepLR, MultiStepLR, ExponentialLR,
    CosineAnnealingLR, CosineAnnealingWarmRestarts, OneCycleLR,
    ReduceLROnPlateau, LinearLR, AnnealStrategy,
    PolynomialLR, CyclicalLR, CyclicalMode, ScaleMode, WarmupScheduler,
    LRScheduleVisualizer,
};
pub use loss::{LossFunction, MSELoss, MAELoss, CrossEntropyLoss};
pub use persistence::{ModelPersistence, PersistentModel, ModelMetadata, PersistenceError};

#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::arr2;

    #[test]
    fn test_library_integration() {
        let mut network = models::lstm_network::LSTMNetwork::new(2, 3, 1);
        let input = arr2(&[[1.0], [0.5]]);
        let hx = arr2(&[[0.0], [0.0], [0.0]]);
        let cx = arr2(&[[0.0], [0.0], [0.0]]);

        let (hy, cy) = network.forward(&input, &hx, &cx);

        assert_eq!(hy.shape(), &[3, 1]);
        assert_eq!(cy.shape(), &[3, 1]);
    }

    #[test]
    fn test_library_with_dropout() {
        let mut network = models::lstm_network::LSTMNetwork::new(2, 3, 1)
            .with_input_dropout(0.2, false)
            .with_recurrent_dropout(0.3, true)
            .with_output_dropout(0.1);

        let input = arr2(&[[1.0], [0.5]]);
        let hx = arr2(&[[0.0], [0.0], [0.0]]);
        let cx = arr2(&[[0.0], [0.0], [0.0]]);

        // Test training mode
        network.train();
        let (hy_train, cy_train) = network.forward(&input, &hx, &cx);

        // Test evaluation mode
        network.eval();
        let (hy_eval, cy_eval) = network.forward(&input, &hx, &cx);

        assert_eq!(hy_train.shape(), &[3, 1]);
        assert_eq!(cy_train.shape(), &[3, 1]);
        assert_eq!(hy_eval.shape(), &[3, 1]);
        assert_eq!(cy_eval.shape(), &[3, 1]);
    }
}