ghostflow_nn/rnn.rs

//! Recurrent Neural Network Layers
//!
//! Implements LSTM and GRU cells and layers for sequence modeling.

use ghostflow_core::Tensor;
use crate::module::Module;
use crate::linear::Linear;

/// LSTM Cell - Long Short-Term Memory
///
/// Implements the LSTM equations:
/// - i_t = σ(W_ii * x_t + b_ii + W_hi * h_(t-1) + b_hi)  [input gate]
/// - f_t = σ(W_if * x_t + b_if + W_hf * h_(t-1) + b_hf)  [forget gate]
/// - g_t = tanh(W_ig * x_t + b_ig + W_hg * h_(t-1) + b_hg)  [cell gate]
/// - o_t = σ(W_io * x_t + b_io + W_ho * h_(t-1) + b_ho)  [output gate]
/// - c_t = f_t ⊙ c_(t-1) + i_t ⊙ g_t  [cell state]
/// - h_t = o_t ⊙ tanh(c_t)  [hidden state]
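///
/// A minimal single-step sketch (illustrative only; marked `ignore` since it
/// assumes the `Tensor` constructors used in the tests at the bottom of this file):
/// ```ignore
/// let cell = LSTMCell::new(8, 16);
/// let x = Tensor::randn(&[4, 8]);    // [batch, input_size]
/// let h0 = Tensor::zeros(&[4, 16]);  // [batch, hidden_size]
/// let c0 = Tensor::zeros(&[4, 16]);
/// let (h1, c1) = cell.forward_cell(&x, &h0, &c0);
/// assert_eq!(h1.dims(), &[4, 16]);
/// assert_eq!(c1.dims(), &[4, 16]);
/// ```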
pub struct LSTMCell {
    input_size: usize,
    hidden_size: usize,

    // Input-to-hidden weights (combined for all gates)
    w_ih: Linear,
    // Hidden-to-hidden weights (combined for all gates)
    w_hh: Linear,

    training: bool,
}

impl LSTMCell {
    /// Create a new LSTM cell
    ///
    /// # Arguments
    /// * `input_size` - Size of input features
    /// * `hidden_size` - Size of hidden state
    pub fn new(input_size: usize, hidden_size: usize) -> Self {
        LSTMCell {
            input_size,
            hidden_size,
            // 4 * hidden_size for i, f, g, o gates
            w_ih: Linear::new(input_size, 4 * hidden_size),
            w_hh: Linear::new(hidden_size, 4 * hidden_size),
            training: true,
        }
    }

    /// Forward pass through LSTM cell
    ///
    /// # Arguments
    /// * `input` - Input tensor of shape [batch, input_size]
    /// * `hidden` - Previous hidden state [batch, hidden_size]
    /// * `cell` - Previous cell state [batch, hidden_size]
    ///
    /// # Returns
    /// Tuple of (new_hidden, new_cell)
    pub fn forward_cell(&self, input: &Tensor, hidden: &Tensor, cell: &Tensor) -> (Tensor, Tensor) {
        let batch_size = input.dims()[0];

        // Compute all gate pre-activations at once: [batch, 4 * hidden_size]
        let gates = self.w_ih.forward(input)
            .add(&self.w_hh.forward(hidden))
            .unwrap();

        let gates_data = gates.data_f32();
        let cell_data = cell.data_f32();

        let mut new_cell_data = vec![0.0f32; batch_size * self.hidden_size];
        let mut new_hidden_data = vec![0.0f32; batch_size * self.hidden_size];

        for b in 0..batch_size {
            for h in 0..self.hidden_size {
                let base_idx = b * 4 * self.hidden_size;

                // Extract gates (layout: [i | f | g | o], each of width hidden_size)
                let i = sigmoid(gates_data[base_idx + h]);                          // input gate
                let f = sigmoid(gates_data[base_idx + self.hidden_size + h]);      // forget gate
                let g = tanh(gates_data[base_idx + 2 * self.hidden_size + h]);     // cell gate
                let o = sigmoid(gates_data[base_idx + 3 * self.hidden_size + h]);  // output gate

                // Update cell state: c_t = f_t * c_(t-1) + i_t * g_t
                let c_prev = cell_data[b * self.hidden_size + h];
                let c_new = f * c_prev + i * g;
                new_cell_data[b * self.hidden_size + h] = c_new;

                // Update hidden state: h_t = o_t * tanh(c_t)
                new_hidden_data[b * self.hidden_size + h] = o * tanh(c_new);
            }
        }

        let new_hidden = Tensor::from_slice(&new_hidden_data, &[batch_size, self.hidden_size]).unwrap();
        let new_cell = Tensor::from_slice(&new_cell_data, &[batch_size, self.hidden_size]).unwrap();

        (new_hidden, new_cell)
    }
}

impl Module for LSTMCell {
    fn forward(&self, input: &Tensor) -> Tensor {
        // Runs a single step from zero-initialized hidden and cell states;
        // use `forward_cell` directly to thread state across time steps.
        let batch_size = input.dims()[0];
        let hidden = Tensor::zeros(&[batch_size, self.hidden_size]);
        let cell = Tensor::zeros(&[batch_size, self.hidden_size]);
        let (h, _) = self.forward_cell(input, &hidden, &cell);
        h
    }

    fn parameters(&self) -> Vec<Tensor> {
        let mut params = self.w_ih.parameters();
        params.extend(self.w_hh.parameters());
        params
    }

    fn train(&mut self) { self.training = true; }
    fn eval(&mut self) { self.training = false; }
    fn is_training(&self) -> bool { self.training }
}

/// LSTM Layer - processes entire sequences
pub struct LSTM {
    cell: LSTMCell,
    num_layers: usize,
    bidirectional: bool,
    dropout: f32,
    training: bool,
}

impl LSTM {
    /// Create a new LSTM layer
    ///
    /// # Arguments
    /// * `input_size` - Size of input features
    /// * `hidden_size` - Size of hidden state
    /// * `num_layers` - Number of stacked LSTM layers (stored, but stacking is not yet implemented)
    /// * `bidirectional` - Whether to use bidirectional LSTM
    /// * `dropout` - Dropout probability between layers (stored, but not yet applied)
    pub fn new(
        input_size: usize,
        hidden_size: usize,
        num_layers: usize,
        bidirectional: bool,
        dropout: f32,
    ) -> Self {
        // NOTE: only a single cell is created; `num_layers` and `dropout` are
        // recorded but the forward pass currently runs one layer.
        LSTM {
            cell: LSTMCell::new(input_size, hidden_size),
            num_layers,
            bidirectional,
            dropout,
            training: true,
        }
    }

    /// Forward pass through LSTM
    ///
    /// # Arguments
    /// * `input` - Input tensor of shape [batch, seq_len, input_size]
    ///
    /// # Returns
    /// Output tensor of shape [batch, seq_len, hidden_size * num_directions]
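    ///
    /// A minimal sketch (illustrative only, `ignore`d for the same reason as the
    /// cell-level example above):
    /// ```ignore
    /// let lstm = LSTM::new(10, 20, 1, false, 0.0);
    /// let input = Tensor::randn(&[2, 5, 10]); // [batch, seq_len, input_size]
    /// let output = lstm.forward_sequence(&input);
    /// assert_eq!(output.dims(), &[2, 5, 20]);
    /// ```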
    pub fn forward_sequence(&self, input: &Tensor) -> Tensor {
        let batch_size = input.dims()[0];
        let seq_len = input.dims()[1];
        let input_size = input.dims()[2];

        let hidden_size = self.cell.hidden_size;
        let num_directions = if self.bidirectional { 2 } else { 1 };

        // Initialize hidden and cell states
        let mut h = Tensor::zeros(&[batch_size, hidden_size]);
        let mut c = Tensor::zeros(&[batch_size, hidden_size]);

        let input_data = input.data_f32();
        let mut output_data = vec![0.0f32; batch_size * seq_len * hidden_size * num_directions];

        // Forward direction
        for t in 0..seq_len {
            // Extract input at time t
            let mut x_t_data = vec![0.0f32; batch_size * input_size];
            for b in 0..batch_size {
                for i in 0..input_size {
                    x_t_data[b * input_size + i] =
                        input_data[b * seq_len * input_size + t * input_size + i];
                }
            }
            let x_t = Tensor::from_slice(&x_t_data, &[batch_size, input_size]).unwrap();

            // LSTM cell forward
            let (h_new, c_new) = self.cell.forward_cell(&x_t, &h, &c);
            h = h_new;
            c = c_new;

            // Store output
            let h_data = h.data_f32();
            for b in 0..batch_size {
                for h_idx in 0..hidden_size {
                    output_data[b * seq_len * hidden_size * num_directions +
                               t * hidden_size * num_directions + h_idx] = h_data[b * hidden_size + h_idx];
                }
            }
        }

        // Backward direction (if bidirectional)
        if self.bidirectional {
            // NOTE: the backward pass reuses the same cell (and therefore the
            // same weights) as the forward pass; a separate backward cell would
            // be needed for a standard bidirectional LSTM.
            let mut h_back = Tensor::zeros(&[batch_size, hidden_size]);
            let mut c_back = Tensor::zeros(&[batch_size, hidden_size]);

            for t in (0..seq_len).rev() {
                let mut x_t_data = vec![0.0f32; batch_size * input_size];
                for b in 0..batch_size {
                    for i in 0..input_size {
                        x_t_data[b * input_size + i] =
                            input_data[b * seq_len * input_size + t * input_size + i];
                    }
                }
                let x_t = Tensor::from_slice(&x_t_data, &[batch_size, input_size]).unwrap();

                let (h_new, c_new) = self.cell.forward_cell(&x_t, &h_back, &c_back);
                h_back = h_new;
                c_back = c_new;

                // Backward outputs occupy the second half of the feature dimension
                let h_data = h_back.data_f32();
                for b in 0..batch_size {
                    for h_idx in 0..hidden_size {
                        output_data[b * seq_len * hidden_size * num_directions +
                                   t * hidden_size * num_directions + hidden_size + h_idx] =
                            h_data[b * hidden_size + h_idx];
                    }
                }
            }
        }

        Tensor::from_slice(
            &output_data,
            &[batch_size, seq_len, hidden_size * num_directions]
        ).unwrap()
    }
}

impl Module for LSTM {
    fn forward(&self, input: &Tensor) -> Tensor {
        self.forward_sequence(input)
    }

    fn parameters(&self) -> Vec<Tensor> {
        self.cell.parameters()
    }

    fn train(&mut self) {
        self.training = true;
        self.cell.train();
    }

    fn eval(&mut self) {
        self.training = false;
        self.cell.eval();
    }

    fn is_training(&self) -> bool { self.training }
}

/// GRU Cell - Gated Recurrent Unit
///
/// Implements the GRU equations:
/// - r_t = σ(W_ir * x_t + b_ir + W_hr * h_(t-1) + b_hr)  [reset gate]
/// - z_t = σ(W_iz * x_t + b_iz + W_hz * h_(t-1) + b_hz)  [update gate]
/// - n_t = tanh(W_in * x_t + b_in + r_t ⊙ (W_hn * h_(t-1) + b_hn))  [new gate]
/// - h_t = (1 - z_t) ⊙ n_t + z_t ⊙ h_(t-1)  [hidden state]
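///
/// A minimal single-step sketch (illustrative only; marked `ignore` since it
/// assumes the `Tensor` constructors used in the tests at the bottom of this file):
/// ```ignore
/// let cell = GRUCell::new(8, 16);
/// let x = Tensor::randn(&[4, 8]);    // [batch, input_size]
/// let h0 = Tensor::zeros(&[4, 16]);  // [batch, hidden_size]
/// let h1 = cell.forward_cell(&x, &h0);
/// assert_eq!(h1.dims(), &[4, 16]);
/// ```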
pub struct GRUCell {
    input_size: usize,
    hidden_size: usize,

    // Input-to-hidden weights (combined for all gates)
    w_ih: Linear,
    // Hidden-to-hidden weights (combined for all gates)
    w_hh: Linear,

    training: bool,
}

impl GRUCell {
    /// Create a new GRU cell
    pub fn new(input_size: usize, hidden_size: usize) -> Self {
        GRUCell {
            input_size,
            hidden_size,
            // 3 * hidden_size for r, z, n gates
            w_ih: Linear::new(input_size, 3 * hidden_size),
            w_hh: Linear::new(hidden_size, 3 * hidden_size),
            training: true,
        }
    }

    /// Forward pass through GRU cell
    ///
    /// # Arguments
    /// * `input` - Input tensor of shape [batch, input_size]
    /// * `hidden` - Previous hidden state [batch, hidden_size]
    ///
    /// # Returns
    /// New hidden state [batch, hidden_size]
    pub fn forward_cell(&self, input: &Tensor, hidden: &Tensor) -> Tensor {
        let batch_size = input.dims()[0];

        // Compute gate pre-activations (layout: [r | z | n], each of width hidden_size)
        let gi = self.w_ih.forward(input);
        let gh = self.w_hh.forward(hidden);

        let gi_data = gi.data_f32();
        let gh_data = gh.data_f32();
        let h_data = hidden.data_f32();

        let mut new_hidden_data = vec![0.0f32; batch_size * self.hidden_size];

        for b in 0..batch_size {
            for h in 0..self.hidden_size {
                // Reset gate
                let r = sigmoid(
                    gi_data[b * 3 * self.hidden_size + h] +
                    gh_data[b * 3 * self.hidden_size + h]
                );

                // Update gate
                let z = sigmoid(
                    gi_data[b * 3 * self.hidden_size + self.hidden_size + h] +
                    gh_data[b * 3 * self.hidden_size + self.hidden_size + h]
                );

                // New gate (the reset gate scales the hidden contribution)
                let n = tanh(
                    gi_data[b * 3 * self.hidden_size + 2 * self.hidden_size + h] +
                    r * gh_data[b * 3 * self.hidden_size + 2 * self.hidden_size + h]
                );

                // New hidden state: h_t = (1 - z_t) * n_t + z_t * h_(t-1)
                let h_prev = h_data[b * self.hidden_size + h];
                new_hidden_data[b * self.hidden_size + h] = (1.0 - z) * n + z * h_prev;
            }
        }

        Tensor::from_slice(&new_hidden_data, &[batch_size, self.hidden_size]).unwrap()
    }
}

impl Module for GRUCell {
    fn forward(&self, input: &Tensor) -> Tensor {
        // Runs a single step from a zero-initialized hidden state;
        // use `forward_cell` directly to thread state across time steps.
        let batch_size = input.dims()[0];
        let hidden = Tensor::zeros(&[batch_size, self.hidden_size]);
        self.forward_cell(input, &hidden)
    }

    fn parameters(&self) -> Vec<Tensor> {
        let mut params = self.w_ih.parameters();
        params.extend(self.w_hh.parameters());
        params
    }

    fn train(&mut self) { self.training = true; }
    fn eval(&mut self) { self.training = false; }
    fn is_training(&self) -> bool { self.training }
}

/// GRU Layer - processes entire sequences
pub struct GRU {
    cell: GRUCell,
    num_layers: usize,
    bidirectional: bool,
    dropout: f32,
    training: bool,
}

impl GRU {
    /// Create a new GRU layer
    ///
    /// # Arguments
    /// * `input_size` - Size of input features
    /// * `hidden_size` - Size of hidden state
    /// * `num_layers` - Number of stacked GRU layers (stored, but stacking is not yet implemented)
    /// * `bidirectional` - Whether to use bidirectional GRU
    /// * `dropout` - Dropout probability between layers (stored, but not yet applied)
    pub fn new(
        input_size: usize,
        hidden_size: usize,
        num_layers: usize,
        bidirectional: bool,
        dropout: f32,
    ) -> Self {
        GRU {
            cell: GRUCell::new(input_size, hidden_size),
            num_layers,
            bidirectional,
            dropout,
            training: true,
        }
    }

    /// Forward pass through GRU
    ///
    /// # Arguments
    /// * `input` - Input tensor of shape [batch, seq_len, input_size]
    ///
    /// # Returns
    /// Output tensor of shape [batch, seq_len, hidden_size * num_directions]
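    ///
    /// A minimal sketch (illustrative only, `ignore`d as with the examples above):
    /// ```ignore
    /// let gru = GRU::new(10, 20, 1, false, 0.0);
    /// let input = Tensor::randn(&[2, 5, 10]); // [batch, seq_len, input_size]
    /// let output = gru.forward_sequence(&input);
    /// assert_eq!(output.dims(), &[2, 5, 20]);
    /// ```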
    pub fn forward_sequence(&self, input: &Tensor) -> Tensor {
        let batch_size = input.dims()[0];
        let seq_len = input.dims()[1];
        let input_size = input.dims()[2];

        let hidden_size = self.cell.hidden_size;
        let num_directions = if self.bidirectional { 2 } else { 1 };

        let mut h = Tensor::zeros(&[batch_size, hidden_size]);

        let input_data = input.data_f32();
        let mut output_data = vec![0.0f32; batch_size * seq_len * hidden_size * num_directions];

        // Forward direction
        for t in 0..seq_len {
            // Extract input at time t
            let mut x_t_data = vec![0.0f32; batch_size * input_size];
            for b in 0..batch_size {
                for i in 0..input_size {
                    x_t_data[b * input_size + i] =
                        input_data[b * seq_len * input_size + t * input_size + i];
                }
            }
            let x_t = Tensor::from_slice(&x_t_data, &[batch_size, input_size]).unwrap();

            h = self.cell.forward_cell(&x_t, &h);

            let h_data = h.data_f32();
            for b in 0..batch_size {
                for h_idx in 0..hidden_size {
                    output_data[b * seq_len * hidden_size * num_directions +
                               t * hidden_size * num_directions + h_idx] = h_data[b * hidden_size + h_idx];
                }
            }
        }

        // Backward direction (if bidirectional)
        if self.bidirectional {
            // NOTE: as with the LSTM above, the backward pass reuses the same
            // cell (and therefore the same weights) as the forward pass.
            let mut h_back = Tensor::zeros(&[batch_size, hidden_size]);

            for t in (0..seq_len).rev() {
                let mut x_t_data = vec![0.0f32; batch_size * input_size];
                for b in 0..batch_size {
                    for i in 0..input_size {
                        x_t_data[b * input_size + i] =
                            input_data[b * seq_len * input_size + t * input_size + i];
                    }
                }
                let x_t = Tensor::from_slice(&x_t_data, &[batch_size, input_size]).unwrap();

                h_back = self.cell.forward_cell(&x_t, &h_back);

                // Backward outputs occupy the second half of the feature dimension
                let h_data = h_back.data_f32();
                for b in 0..batch_size {
                    for h_idx in 0..hidden_size {
                        output_data[b * seq_len * hidden_size * num_directions +
                                   t * hidden_size * num_directions + hidden_size + h_idx] =
                            h_data[b * hidden_size + h_idx];
                    }
                }
            }
        }

        Tensor::from_slice(
            &output_data,
            &[batch_size, seq_len, hidden_size * num_directions]
        ).unwrap()
    }
}

impl Module for GRU {
    fn forward(&self, input: &Tensor) -> Tensor {
        self.forward_sequence(input)
    }

    fn parameters(&self) -> Vec<Tensor> {
        self.cell.parameters()
    }

    fn train(&mut self) {
        self.training = true;
        self.cell.train();
    }

    fn eval(&mut self) {
        self.training = false;
        self.cell.eval();
    }

    fn is_training(&self) -> bool { self.training }
}

// Helper functions
#[inline]
fn sigmoid(x: f32) -> f32 {
    1.0 / (1.0 + (-x).exp())
}

// Thin wrapper kept for symmetry with `sigmoid` in the gate math above.
#[inline]
fn tanh(x: f32) -> f32 {
    x.tanh()
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_lstm_cell() {
        let cell = LSTMCell::new(10, 20);
        let input = Tensor::randn(&[2, 10]);
        let hidden = Tensor::zeros(&[2, 20]);
        let cell_state = Tensor::zeros(&[2, 20]);

        let (h, c) = cell.forward_cell(&input, &hidden, &cell_state);

        assert_eq!(h.dims(), &[2, 20]);
        assert_eq!(c.dims(), &[2, 20]);
    }

    #[test]
    fn test_lstm_sequence() {
        let lstm = LSTM::new(10, 20, 1, false, 0.0);
        let input = Tensor::randn(&[2, 5, 10]); // [batch, seq, features]

        let output = lstm.forward_sequence(&input);

        assert_eq!(output.dims(), &[2, 5, 20]);
    }

    #[test]
    fn test_lstm_bidirectional() {
        let lstm = LSTM::new(10, 20, 1, true, 0.0);
        let input = Tensor::randn(&[2, 5, 10]);

        let output = lstm.forward_sequence(&input);

        assert_eq!(output.dims(), &[2, 5, 40]); // 20 * 2 directions
    }

    #[test]
    fn test_gru_cell() {
        let cell = GRUCell::new(10, 20);
        let input = Tensor::randn(&[2, 10]);
        let hidden = Tensor::zeros(&[2, 20]);

        let h = cell.forward_cell(&input, &hidden);

        assert_eq!(h.dims(), &[2, 20]);
    }

    #[test]
    fn test_gru_sequence() {
        let gru = GRU::new(10, 20, 1, false, 0.0);
        let input = Tensor::randn(&[2, 5, 10]);

        let output = gru.forward_sequence(&input);

        assert_eq!(output.dims(), &[2, 5, 20]);
    }

    #[test]
    fn test_gru_bidirectional() {
        let gru = GRU::new(10, 20, 1, true, 0.0);
        let input = Tensor::randn(&[2, 5, 10]);

        let output = gru.forward_sequence(&input);

        assert_eq!(output.dims(), &[2, 5, 40]); // 20 * 2 directions
    }
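
    // Added sanity checks (sketches, assuming the same Tensor API used above:
    // `randn`, `zeros`, `dims`, and `data_f32`). They test range invariants
    // that follow directly from the gate equations, not learned behavior.

    #[test]
    fn test_sigmoid_tanh_helpers() {
        // sigmoid is 0.5 at the origin and saturates toward 0 and 1
        assert!((sigmoid(0.0) - 0.5).abs() < 1e-6);
        assert!(sigmoid(10.0) > 0.99 && sigmoid(-10.0) < 0.01);
        // tanh is zero at the origin and saturates toward -1 and 1
        assert!(tanh(0.0).abs() < 1e-6);
        assert!(tanh(5.0) > 0.99 && tanh(-5.0) < -0.99);
    }

    #[test]
    fn test_lstm_hidden_bounded() {
        // h_t = o_t * tanh(c_t) with o_t in (0, 1), so every hidden value
        // must lie in [-1, 1] regardless of the random weights.
        let cell = LSTMCell::new(10, 20);
        let input = Tensor::randn(&[2, 10]);
        let hidden = Tensor::zeros(&[2, 20]);
        let cell_state = Tensor::zeros(&[2, 20]);

        let (h, _) = cell.forward_cell(&input, &hidden, &cell_state);
        assert!(h.data_f32().iter().all(|&v| v.abs() <= 1.0));
    }

    #[test]
    fn test_gru_hidden_bounded() {
        // With h_prev = 0, h_t = (1 - z_t) * n_t, and n_t = tanh(...) lies in
        // (-1, 1), so the first-step hidden values are bounded the same way.
        let cell = GRUCell::new(10, 20);
        let input = Tensor::randn(&[2, 10]);
        let hidden = Tensor::zeros(&[2, 20]);

        let h = cell.forward_cell(&input, &hidden);
        assert!(h.data_f32().iter().all(|&v| v.abs() <= 1.0));
    }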
}