// axonml_nn/layers/rnn.rs

//! Recurrent Neural Network Layers - RNN, LSTM, GRU
//!
//! Processes sequential data with recurrent connections.
//!
//! @version 0.1.0
//! @author AutomataNexus Development Team

use std::collections::HashMap;

use axonml_autograd::Variable;
use axonml_tensor::Tensor;

use crate::init::{xavier_uniform, zeros};
use crate::module::Module;
use crate::parameter::Parameter;

// =============================================================================
// RNNCell
// =============================================================================

/// A single RNN cell.
///
/// h' = tanh(W_ih * x + b_ih + W_hh * h + b_hh)
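///
/// # Example
///
/// A minimal usage sketch for a single step (mirrors the unit test at the
/// bottom of this file):
///
/// ```ignore
/// let cell = RNNCell::new(10, 20);
/// let x = Variable::new(Tensor::from_vec(vec![1.0; 20], &[2, 10]).unwrap(), false);
/// let h = Variable::new(Tensor::from_vec(vec![0.0; 40], &[2, 20]).unwrap(), false);
/// let h_next = cell.forward_step(&x, &h); // shape [2, 20]
/// ```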
pub struct RNNCell {
    /// Input-hidden weights.
    pub weight_ih: Parameter,
    /// Hidden-hidden weights.
    pub weight_hh: Parameter,
    /// Input-hidden bias.
    pub bias_ih: Parameter,
    /// Hidden-hidden bias.
    pub bias_hh: Parameter,
    /// Input size.
    input_size: usize,
    /// Hidden size.
    hidden_size: usize,
}

impl RNNCell {
    /// Creates a new RNNCell.
    pub fn new(input_size: usize, hidden_size: usize) -> Self {
        Self {
            weight_ih: Parameter::named("weight_ih", xavier_uniform(input_size, hidden_size), true),
            weight_hh: Parameter::named(
                "weight_hh",
                xavier_uniform(hidden_size, hidden_size),
                true,
            ),
            bias_ih: Parameter::named("bias_ih", zeros(&[hidden_size]), true),
            bias_hh: Parameter::named("bias_hh", zeros(&[hidden_size]), true),
            input_size,
            hidden_size,
        }
    }

    /// Returns the expected input size.
    pub fn input_size(&self) -> usize {
        self.input_size
    }

    /// Returns the hidden state size.
    pub fn hidden_size(&self) -> usize {
        self.hidden_size
    }

    /// Forward pass for a single time step.
    pub fn forward_step(&self, input: &Variable, hidden: &Variable) -> Variable {
        let input_features = input.data().shape().last().copied().unwrap_or(0);
        assert_eq!(
            input_features, self.input_size,
            "RNNCell: expected input size {}, got {}",
            self.input_size, input_features
        );
        // x @ W_ih^T + b_ih
        let weight_ih = self.weight_ih.variable();
        let weight_ih_t = Variable::new(weight_ih.data().t().unwrap(), weight_ih.requires_grad());
        let ih = input.matmul(&weight_ih_t);
        let bias_ih = self.bias_ih.variable();
        let ih = ih.add_var(&bias_ih);

        // h @ W_hh^T + b_hh
        let weight_hh = self.weight_hh.variable();
        let weight_hh_t = Variable::new(weight_hh.data().t().unwrap(), weight_hh.requires_grad());
        let hh = hidden.matmul(&weight_hh_t);
        let bias_hh = self.bias_hh.variable();
        let hh = hh.add_var(&bias_hh);

        // tanh(ih + hh)
        ih.add_var(&hh).tanh()
    }
}

impl Module for RNNCell {
    fn forward(&self, input: &Variable) -> Variable {
        // Initialize hidden state to zeros
        let batch_size = input.shape()[0];
        let hidden = Variable::new(
            zeros(&[batch_size, self.hidden_size]),
            input.requires_grad(),
        );
        self.forward_step(input, &hidden)
    }

    fn parameters(&self) -> Vec<Parameter> {
        vec![
            self.weight_ih.clone(),
            self.weight_hh.clone(),
            self.bias_ih.clone(),
            self.bias_hh.clone(),
        ]
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut params = HashMap::new();
        params.insert("weight_ih".to_string(), self.weight_ih.clone());
        params.insert("weight_hh".to_string(), self.weight_hh.clone());
        params.insert("bias_ih".to_string(), self.bias_ih.clone());
        params.insert("bias_hh".to_string(), self.bias_hh.clone());
        params
    }

    fn name(&self) -> &'static str {
        "RNNCell"
    }
}

// =============================================================================
// RNN
// =============================================================================

/// Multi-layer RNN.
///
/// Processes sequences through stacked RNN layers.
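///
/// # Example
///
/// A minimal sketch with a batch-first input of shape [batch, seq, features]
/// (mirrors the unit test at the bottom of this file):
///
/// ```ignore
/// let rnn = RNN::new(10, 20, 2);
/// let input = Variable::new(Tensor::from_vec(vec![1.0; 100], &[2, 5, 10]).unwrap(), false);
/// let output = rnn.forward(&input); // shape [2, 5, 20]
/// ```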
pub struct RNN {
    /// RNN cells for each layer.
    cells: Vec<RNNCell>,
    /// Input size.
    input_size: usize,
    /// Hidden size.
    hidden_size: usize,
    /// Number of layers.
    num_layers: usize,
    /// Batch first flag.
    batch_first: bool,
}

impl RNN {
    /// Creates a new multi-layer RNN.
    pub fn new(input_size: usize, hidden_size: usize, num_layers: usize) -> Self {
        Self::with_options(input_size, hidden_size, num_layers, true)
    }

    /// Creates an RNN with all options.
    pub fn with_options(
        input_size: usize,
        hidden_size: usize,
        num_layers: usize,
        batch_first: bool,
    ) -> Self {
        let mut cells = Vec::with_capacity(num_layers);

        // First layer takes input_size
        cells.push(RNNCell::new(input_size, hidden_size));

        // Subsequent layers take hidden_size
        for _ in 1..num_layers {
            cells.push(RNNCell::new(hidden_size, hidden_size));
        }

        Self {
            cells,
            input_size,
            hidden_size,
            num_layers,
            batch_first,
        }
    }
}

impl Module for RNN {
    fn forward(&self, input: &Variable) -> Variable {
        let shape = input.shape();
        let (batch_size, seq_len, _) = if self.batch_first {
            (shape[0], shape[1], shape[2])
        } else {
            (shape[1], shape[0], shape[2])
        };

        // Initialize hidden states
        let mut hiddens: Vec<Variable> = (0..self.num_layers)
            .map(|_| {
                Variable::new(
                    zeros(&[batch_size, self.hidden_size]),
                    input.requires_grad(),
                )
            })
            .collect();

        // Process each time step
        let input_data = input.data();
        let input_vec = input_data.to_vec();
        let mut outputs = Vec::with_capacity(seq_len);

        for t in 0..seq_len {
            // Extract the [batch, features] slice for time step t
            let mut slice_data = vec![0.0f32; batch_size * self.input_size];
            for b in 0..batch_size {
                for f in 0..self.input_size {
                    let src_idx = if self.batch_first {
                        // [batch, seq, features]
                        b * seq_len * self.input_size + t * self.input_size + f
                    } else {
                        // [seq, batch, features]
                        t * batch_size * self.input_size + b * self.input_size + f
                    };
                    slice_data[b * self.input_size + f] = input_vec[src_idx];
                }
            }
            let t_input = Variable::new(
                Tensor::from_vec(slice_data, &[batch_size, self.input_size]).unwrap(),
                input.requires_grad(),
            );

            // Process through layers
            let mut layer_input = t_input;
            for (l, cell) in self.cells.iter().enumerate() {
                hiddens[l] = cell.forward_step(&layer_input, &hiddens[l]);
                layer_input = hiddens[l].clone();
            }

            outputs.push(hiddens[self.num_layers - 1].clone());
        }

        // Stack outputs
        let output_size = batch_size * seq_len * self.hidden_size;
        let mut output_data = vec![0.0f32; output_size];

        for (t, out) in outputs.iter().enumerate() {
            let out_vec = out.data().to_vec();
            for b in 0..batch_size {
                for h in 0..self.hidden_size {
                    let src_idx = b * self.hidden_size + h;
                    let dst_idx = if self.batch_first {
                        b * seq_len * self.hidden_size + t * self.hidden_size + h
                    } else {
                        t * batch_size * self.hidden_size + b * self.hidden_size + h
                    };
                    output_data[dst_idx] = out_vec[src_idx];
                }
            }
        }

        let output_shape = if self.batch_first {
            vec![batch_size, seq_len, self.hidden_size]
        } else {
            vec![seq_len, batch_size, self.hidden_size]
        };

        Variable::new(
            Tensor::from_vec(output_data, &output_shape).unwrap(),
            input.requires_grad(),
        )
    }

    fn parameters(&self) -> Vec<Parameter> {
        self.cells.iter().flat_map(|c| c.parameters()).collect()
    }

    fn name(&self) -> &'static str {
        "RNN"
    }
}

// =============================================================================
// LSTMCell
// =============================================================================

/// A single LSTM cell.
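///
/// The input-hidden and hidden-hidden weights pack all four gates
/// (input, forget, cell, output), so their output dimension is 4 * hidden_size.
///
/// # Example
///
/// A minimal single-step sketch:
///
/// ```ignore
/// let cell = LSTMCell::new(10, 20);
/// let x = Variable::new(Tensor::from_vec(vec![1.0; 20], &[2, 10]).unwrap(), false);
/// let h = Variable::new(zeros(&[2, 20]), false);
/// let c = Variable::new(zeros(&[2, 20]), false);
/// let (h_next, c_next) = cell.forward_step(&x, &(h, c)); // both [2, 20]
/// ```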
pub struct LSTMCell {
    /// Input-hidden weights for all gates.
    pub weight_ih: Parameter,
    /// Hidden-hidden weights for all gates.
    pub weight_hh: Parameter,
    /// Input-hidden bias for all gates.
    pub bias_ih: Parameter,
    /// Hidden-hidden bias for all gates.
    pub bias_hh: Parameter,
    /// Input size.
    input_size: usize,
    /// Hidden size.
    hidden_size: usize,
}

impl LSTMCell {
    /// Creates a new LSTMCell.
    pub fn new(input_size: usize, hidden_size: usize) -> Self {
        // LSTM has 4 gates, so weight size is 4*hidden_size
        Self {
            weight_ih: Parameter::named(
                "weight_ih",
                xavier_uniform(input_size, 4 * hidden_size),
                true,
            ),
            weight_hh: Parameter::named(
                "weight_hh",
                xavier_uniform(hidden_size, 4 * hidden_size),
                true,
            ),
            bias_ih: Parameter::named("bias_ih", zeros(&[4 * hidden_size]), true),
            bias_hh: Parameter::named("bias_hh", zeros(&[4 * hidden_size]), true),
            input_size,
            hidden_size,
        }
    }

    /// Returns the expected input size.
    pub fn input_size(&self) -> usize {
        self.input_size
    }

    /// Returns the hidden state size.
    pub fn hidden_size(&self) -> usize {
        self.hidden_size
    }

    /// Forward pass returning (h', c').
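    ///
    /// With the packed gates ordered (i, f, g, o), a single step computes
    /// (restating the computation in the method body):
    ///
    /// i = sigmoid(W_ii * x + b_ii + W_hi * h + b_hi)
    /// f = sigmoid(W_if * x + b_if + W_hf * h + b_hf)
    /// g = tanh(W_ig * x + b_ig + W_hg * h + b_hg)
    /// o = sigmoid(W_io * x + b_io + W_ho * h + b_ho)
    /// c' = f * c + i * g
    /// h' = o * tanh(c')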
    pub fn forward_step(
        &self,
        input: &Variable,
        hx: &(Variable, Variable),
    ) -> (Variable, Variable) {
        let input_features = input.data().shape().last().copied().unwrap_or(0);
        assert_eq!(
            input_features, self.input_size,
            "LSTMCell: expected input size {}, got {}",
            self.input_size, input_features
        );

        let (h, c) = hx;

        // Compute all gate pre-activations at once (x @ W^T + b)
        let weight_ih = self.weight_ih.variable();
        let weight_ih_t = Variable::new(weight_ih.data().t().unwrap(), weight_ih.requires_grad());
        let ih = input.matmul(&weight_ih_t);
        let bias_ih = self.bias_ih.variable();
        let ih = ih.add_var(&bias_ih);

        let weight_hh = self.weight_hh.variable();
        let weight_hh_t = Variable::new(weight_hh.data().t().unwrap(), weight_hh.requires_grad());
        let hh = h.matmul(&weight_hh_t);
        let bias_hh = self.bias_hh.variable();
        let hh = hh.add_var(&bias_hh);

        // [batch, 4*hidden_size] pre-activations, gate order: i, f, g, o
        let gates = ih.add_var(&hh);

        // Split into the 4 gates with narrow so gradients flow back through
        // the gate slices (the same approach GRUCell uses below)
        let i = gates.narrow(1, 0, self.hidden_size).sigmoid();
        let f = gates.narrow(1, self.hidden_size, self.hidden_size).sigmoid();
        let g = gates.narrow(1, 2 * self.hidden_size, self.hidden_size).tanh();
        let o = gates.narrow(1, 3 * self.hidden_size, self.hidden_size).sigmoid();

        // c' = f * c + i * g
        let c_new = f.mul_var(c).add_var(&i.mul_var(&g));

        // h' = o * tanh(c')
        let h_new = o.mul_var(&c_new.tanh());

        (h_new, c_new)
    }
}

impl Module for LSTMCell {
    fn forward(&self, input: &Variable) -> Variable {
        let batch_size = input.shape()[0];
        let h = Variable::new(
            zeros(&[batch_size, self.hidden_size]),
            input.requires_grad(),
        );
        let c = Variable::new(
            zeros(&[batch_size, self.hidden_size]),
            input.requires_grad(),
        );
        let (h_new, _) = self.forward_step(input, &(h, c));
        h_new
    }

    fn parameters(&self) -> Vec<Parameter> {
        vec![
            self.weight_ih.clone(),
            self.weight_hh.clone(),
            self.bias_ih.clone(),
            self.bias_hh.clone(),
        ]
    }

    fn name(&self) -> &'static str {
        "LSTMCell"
    }
}

// =============================================================================
// LSTM
// =============================================================================

/// Multi-layer LSTM.
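///
/// # Example
///
/// A minimal sketch with a batch-first input (mirrors the unit test at the
/// bottom of this file):
///
/// ```ignore
/// let lstm = LSTM::new(10, 20, 1);
/// let input = Variable::new(Tensor::from_vec(vec![1.0; 100], &[2, 5, 10]).unwrap(), false);
/// let output = lstm.forward(&input); // shape [2, 5, 20]
/// ```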
pub struct LSTM {
    /// LSTM cells for each layer.
    cells: Vec<LSTMCell>,
    /// Input size.
    input_size: usize,
    /// Hidden size.
    hidden_size: usize,
    /// Number of layers.
    num_layers: usize,
    /// Batch first flag.
    batch_first: bool,
}

impl LSTM {
    /// Creates a new multi-layer LSTM.
    pub fn new(input_size: usize, hidden_size: usize, num_layers: usize) -> Self {
        Self::with_options(input_size, hidden_size, num_layers, true)
    }

    /// Creates an LSTM with all options.
    pub fn with_options(
        input_size: usize,
        hidden_size: usize,
        num_layers: usize,
        batch_first: bool,
    ) -> Self {
        let mut cells = Vec::with_capacity(num_layers);
        cells.push(LSTMCell::new(input_size, hidden_size));
        for _ in 1..num_layers {
            cells.push(LSTMCell::new(hidden_size, hidden_size));
        }

        Self {
            cells,
            input_size,
            hidden_size,
            num_layers,
            batch_first,
        }
    }

    /// Returns the expected input size.
    pub fn input_size(&self) -> usize {
        self.input_size
    }

    /// Returns the hidden state size.
    pub fn hidden_size(&self) -> usize {
        self.hidden_size
    }

    /// Returns the number of layers.
    pub fn num_layers(&self) -> usize {
        self.num_layers
    }
}

impl Module for LSTM {
    fn forward(&self, input: &Variable) -> Variable {
        // Same time-stepping scheme as RNN::forward, but with (h, c) state
        // pairs and LSTM cells.
        let shape = input.shape();
        let (batch_size, seq_len, input_features) = if self.batch_first {
            (shape[0], shape[1], shape[2])
        } else {
            (shape[1], shape[0], shape[2])
        };

        let mut states: Vec<(Variable, Variable)> = (0..self.num_layers)
            .map(|_| {
                (
                    Variable::new(
                        zeros(&[batch_size, self.hidden_size]),
                        input.requires_grad(),
                    ),
                    Variable::new(
                        zeros(&[batch_size, self.hidden_size]),
                        input.requires_grad(),
                    ),
                )
            })
            .collect();

        let input_data = input.data();
        let input_vec = input_data.to_vec();
        let mut outputs = Vec::with_capacity(seq_len);

        for t in 0..seq_len {
            let mut slice_data = vec![0.0f32; batch_size * input_features];
            for b in 0..batch_size {
                for f in 0..input_features {
                    let src_idx = if self.batch_first {
                        b * seq_len * input_features + t * input_features + f
                    } else {
                        t * batch_size * input_features + b * input_features + f
                    };
                    slice_data[b * input_features + f] = input_vec[src_idx];
                }
            }

            // The slice for time step t always has input_features columns
            let mut layer_input = Variable::new(
                Tensor::from_vec(slice_data.clone(), &[batch_size, input_features]).unwrap(),
                input.requires_grad(),
            );

            for (l, cell) in self.cells.iter().enumerate() {
                // Layers after the first take the previous layer's new hidden state
                if l > 0 {
                    layer_input = states[l - 1].0.clone();
                }
                states[l] = cell.forward_step(&layer_input, &states[l]);
            }

            outputs.push(states[self.num_layers - 1].0.clone());
        }

        // Stack outputs
        let mut output_data = vec![0.0f32; batch_size * seq_len * self.hidden_size];
        for (t, out) in outputs.iter().enumerate() {
            let out_vec = out.data().to_vec();
            for b in 0..batch_size {
                for h in 0..self.hidden_size {
                    let dst_idx = if self.batch_first {
                        b * seq_len * self.hidden_size + t * self.hidden_size + h
                    } else {
                        t * batch_size * self.hidden_size + b * self.hidden_size + h
                    };
                    output_data[dst_idx] = out_vec[b * self.hidden_size + h];
                }
            }
        }

        let output_shape = if self.batch_first {
            vec![batch_size, seq_len, self.hidden_size]
        } else {
            vec![seq_len, batch_size, self.hidden_size]
        };

        Variable::new(
            Tensor::from_vec(output_data, &output_shape).unwrap(),
            input.requires_grad(),
        )
    }

    fn parameters(&self) -> Vec<Parameter> {
        self.cells.iter().flat_map(|c| c.parameters()).collect()
    }

    fn name(&self) -> &'static str {
        "LSTM"
    }
}

// =============================================================================
// GRUCell and GRU
// =============================================================================

/// A single GRU cell.
///
/// h' = (1 - z) * n + z * h
/// where:
///   r = sigmoid(W_ir * x + b_ir + W_hr * h + b_hr)  (reset gate)
///   z = sigmoid(W_iz * x + b_iz + W_hz * h + b_hz)  (update gate)
///   n = tanh(W_in * x + b_in + r * (W_hn * h + b_hn))  (new gate)
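///
/// # Example
///
/// A minimal single-step sketch:
///
/// ```ignore
/// let cell = GRUCell::new(10, 20);
/// let x = Variable::new(Tensor::from_vec(vec![1.0; 20], &[2, 10]).unwrap(), false);
/// let h = Variable::new(zeros(&[2, 20]), false);
/// let h_next = cell.forward_step(&x, &h); // shape [2, 20]
/// ```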
pub struct GRUCell {
    /// Input-hidden weights for all gates (reset, update, new).
    pub weight_ih: Parameter,
    /// Hidden-hidden weights for all gates (reset, update, new).
    pub weight_hh: Parameter,
    /// Input-hidden bias for all gates.
    pub bias_ih: Parameter,
    /// Hidden-hidden bias for all gates.
    pub bias_hh: Parameter,
    /// Input size.
    input_size: usize,
    /// Hidden size.
    hidden_size: usize,
}

impl GRUCell {
    /// Creates a new GRU cell.
    pub fn new(input_size: usize, hidden_size: usize) -> Self {
        Self {
            weight_ih: Parameter::named(
                "weight_ih",
                xavier_uniform(input_size, 3 * hidden_size),
                true,
            ),
            weight_hh: Parameter::named(
                "weight_hh",
                xavier_uniform(hidden_size, 3 * hidden_size),
                true,
            ),
            bias_ih: Parameter::named("bias_ih", zeros(&[3 * hidden_size]), true),
            bias_hh: Parameter::named("bias_hh", zeros(&[3 * hidden_size]), true),
            input_size,
            hidden_size,
        }
    }

    /// Returns the expected input size.
    pub fn input_size(&self) -> usize {
        self.input_size
    }

    /// Returns the hidden state size.
    pub fn hidden_size(&self) -> usize {
        self.hidden_size
    }
    /// Forward pass for a single time step with explicit hidden state.
    ///
    /// GRU equations:
    /// r_t = sigmoid(W_ir @ x_t + b_ir + W_hr @ h_{t-1} + b_hr)
    /// z_t = sigmoid(W_iz @ x_t + b_iz + W_hz @ h_{t-1} + b_hz)
    /// n_t = tanh(W_in @ x_t + b_in + r_t * (W_hn @ h_{t-1} + b_hn))
    /// h_t = (1 - z_t) * n_t + z_t * h_{t-1}
    ///
    /// All computations use Variable operations for proper gradient flow.
    pub fn forward_step(&self, input: &Variable, hidden: &Variable) -> Variable {
        let batch_size = input.shape()[0];
        let hidden_size = self.hidden_size;

        // Get weight matrices
        let weight_ih = self.weight_ih.variable();
        let weight_hh = self.weight_hh.variable();
        let bias_ih = self.bias_ih.variable();
        let bias_hh = self.bias_hh.variable();

        // Compute input transformation: x @ W_ih^T + b_ih
        // Shape: [batch, 3*hidden_size]
        let weight_ih_t = Variable::new(weight_ih.data().t().unwrap(), weight_ih.requires_grad());
        let ih = input.matmul(&weight_ih_t).add_var(&bias_ih);

        // Compute hidden transformation: h @ W_hh^T + b_hh
        // Shape: [batch, 3*hidden_size]
        let weight_hh_t = Variable::new(weight_hh.data().t().unwrap(), weight_hh.requires_grad());
        let hh = hidden.matmul(&weight_hh_t).add_var(&bias_hh);

        // Use narrow to split into gates (preserves gradient flow)
        // Each gate slice: [batch, hidden_size]
        let ih_r = ih.narrow(1, 0, hidden_size);
        let ih_z = ih.narrow(1, hidden_size, hidden_size);
        let ih_n = ih.narrow(1, 2 * hidden_size, hidden_size);

        let hh_r = hh.narrow(1, 0, hidden_size);
        let hh_z = hh.narrow(1, hidden_size, hidden_size);
        let hh_n = hh.narrow(1, 2 * hidden_size, hidden_size);

        // Compute gates using Variable operations for gradient flow
        // r = sigmoid(ih_r + hh_r)
        let r = ih_r.add_var(&hh_r).sigmoid();

        // z = sigmoid(ih_z + hh_z)
        let z = ih_z.add_var(&hh_z).sigmoid();

        // n = tanh(ih_n + r * hh_n)
        let n = ih_n.add_var(&r.mul_var(&hh_n)).tanh();

        // h_new = (1 - z) * n + z * h_prev
        // Create ones for (1 - z)
        let shape = [batch_size, hidden_size];
        let ones = Variable::new(
            Tensor::from_vec(vec![1.0f32; batch_size * hidden_size], &shape).unwrap(),
            false,
        );
        let one_minus_z = ones.sub_var(&z);

        // h_new = one_minus_z * n + z * h_prev
        one_minus_z.mul_var(&n).add_var(&z.mul_var(hidden))
    }
}

impl Module for GRUCell {
    fn forward(&self, input: &Variable) -> Variable {
        let batch_size = input.shape()[0];

        // Initialize hidden state to zeros
        let hidden = Variable::new(
            zeros(&[batch_size, self.hidden_size]),
            input.requires_grad(),
        );

        self.forward_step(input, &hidden)
    }

    fn parameters(&self) -> Vec<Parameter> {
        vec![
            self.weight_ih.clone(),
            self.weight_hh.clone(),
            self.bias_ih.clone(),
            self.bias_hh.clone(),
        ]
    }

    fn name(&self) -> &'static str {
        "GRUCell"
    }
}

/// Multi-layer GRU.
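///
/// # Example
///
/// A minimal sketch with a batch-first input of shape [batch, seq, features]:
///
/// ```ignore
/// let gru = GRU::new(10, 20, 2);
/// let input = Variable::new(Tensor::from_vec(vec![1.0; 100], &[2, 5, 10]).unwrap(), false);
/// let output = gru.forward(&input); // shape [2, 5, 20]
/// ```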
pub struct GRU {
    /// GRU cells for each layer.
    cells: Vec<GRUCell>,
    /// Hidden state size.
    hidden_size: usize,
    /// Number of layers.
    num_layers: usize,
    /// If true, input is (batch, seq, features), else (seq, batch, features).
    batch_first: bool,
}

impl GRU {
    /// Creates a new multi-layer GRU.
    pub fn new(input_size: usize, hidden_size: usize, num_layers: usize) -> Self {
        let mut cells = Vec::with_capacity(num_layers);
        cells.push(GRUCell::new(input_size, hidden_size));
        for _ in 1..num_layers {
            cells.push(GRUCell::new(hidden_size, hidden_size));
        }
        Self {
            cells,
            hidden_size,
            num_layers,
            batch_first: true,
        }
    }

    /// Returns the hidden state size.
    pub fn hidden_size(&self) -> usize {
        self.hidden_size
    }

    /// Returns the number of layers.
    pub fn num_layers(&self) -> usize {
        self.num_layers
    }
}

impl Module for GRU {
    fn forward(&self, input: &Variable) -> Variable {
        let shape = input.shape();
        let (batch_size, seq_len, _input_size) = if self.batch_first {
            (shape[0], shape[1], shape[2])
        } else {
            (shape[1], shape[0], shape[2])
        };

        // Initialize hidden states for all layers as Variables (with gradients)
        let mut hidden_states: Vec<Variable> = (0..self.num_layers)
            .map(|_| {
                Variable::new(
                    zeros(&[batch_size, self.hidden_size]),
                    input.requires_grad(),
                )
            })
            .collect();

        // Collect output Variables for each time step
        let mut output_vars: Vec<Variable> = Vec::with_capacity(seq_len);

        // Process each time step
        for t in 0..seq_len {
            // Extract the input for this time step using narrow (preserves gradients):
            // with the batch-first layout ([batch, seq, features]), narrow along dim 1
            // gives [batch, 1, features], which is then reshaped to [batch, features]
            let narrowed = input.narrow(1, t, 1);
            let step_input = narrowed.reshape(&[batch_size, narrowed.data().numel() / batch_size]);

            // Process through each layer
            let mut layer_input = step_input;

            for (layer_idx, cell) in self.cells.iter().enumerate() {
                let new_hidden = cell.forward_step(&layer_input, &hidden_states[layer_idx]);

                // Update hidden state for this layer (keeps gradient chain)
                hidden_states[layer_idx] = new_hidden.clone();

                // Output of this layer becomes input to next layer
                layer_input = new_hidden;
            }

            // Store output from last layer for this time step
            output_vars.push(layer_input);
        }

        // Stack outputs along the time dimension:
        // each output_var has shape [batch, hidden_size], combined into
        // [batch, seq, hidden_size]
        self.stack_outputs(&output_vars, batch_size, seq_len)
    }

    fn parameters(&self) -> Vec<Parameter> {
        self.cells.iter().flat_map(|c| c.parameters()).collect()
    }

    fn name(&self) -> &'static str {
        "GRU"
    }
}

impl GRU {
    /// Forward pass that returns the mean of all hidden states.
    /// This is equivalent to running the full sequence and then mean pooling,
    /// but keeps proper gradient flow.
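    ///
    /// A minimal sketch (shapes follow the constructor arguments):
    ///
    /// ```ignore
    /// let gru = GRU::new(10, 20, 2);
    /// let input = Variable::new(Tensor::from_vec(vec![1.0; 100], &[2, 5, 10]).unwrap(), false);
    /// let pooled = gru.forward_mean(&input); // shape [2, 20]
    /// ```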
    pub fn forward_mean(&self, input: &Variable) -> Variable {
        let shape = input.shape();
        let (batch_size, seq_len, _input_size) = if self.batch_first {
            (shape[0], shape[1], shape[2])
        } else {
            (shape[1], shape[0], shape[2])
        };

        // Initialize hidden states for all layers as Variables (with gradients)
        let mut hidden_states: Vec<Variable> = (0..self.num_layers)
            .map(|_| {
                Variable::new(
                    zeros(&[batch_size, self.hidden_size]),
                    input.requires_grad(),
                )
            })
            .collect();

        // Accumulator for mean of outputs
        let mut output_sum: Option<Variable> = None;

        // Process each time step
        for t in 0..seq_len {
            // Extract input for this time step using narrow (preserves gradients)
            // narrow gives [batch, 1, features], reshape to [batch, features]
            let narrowed = input.narrow(1, t, 1);
            let step_input = narrowed.reshape(&[batch_size, narrowed.data().numel() / batch_size]);

            // Process through each layer
            let mut layer_input = step_input;

            for (layer_idx, cell) in self.cells.iter().enumerate() {
                let new_hidden = cell.forward_step(&layer_input, &hidden_states[layer_idx]);
                hidden_states[layer_idx] = new_hidden.clone();
                layer_input = new_hidden;
            }

            // Accumulate output (last layer's hidden state)
            output_sum = Some(match output_sum {
                None => layer_input,
                Some(acc) => acc.add_var(&layer_input),
            });
        }

        // Return mean of all hidden states
        match output_sum {
            Some(sum) => sum.mul_scalar(1.0 / seq_len as f32),
            None => Variable::new(zeros(&[batch_size, self.hidden_size]), false),
        }
    }

    /// Forward pass that returns the last hidden state.
    /// Good for sequence classification with proper gradient flow.
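    ///
    /// A minimal sketch (shapes follow the constructor arguments):
    ///
    /// ```ignore
    /// let gru = GRU::new(10, 20, 2);
    /// let input = Variable::new(Tensor::from_vec(vec![1.0; 100], &[2, 5, 10]).unwrap(), false);
    /// let last = gru.forward_last(&input); // shape [2, 20]
    /// ```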
    pub fn forward_last(&self, input: &Variable) -> Variable {
        let shape = input.shape();
        let (batch_size, seq_len, _input_size) = if self.batch_first {
            (shape[0], shape[1], shape[2])
        } else {
            (shape[1], shape[0], shape[2])
        };

        // Initialize hidden states for all layers
        let mut hidden_states: Vec<Variable> = (0..self.num_layers)
            .map(|_| {
                Variable::new(
                    zeros(&[batch_size, self.hidden_size]),
                    input.requires_grad(),
                )
            })
            .collect();

        // Process each time step
        for t in 0..seq_len {
            // narrow gives [batch, 1, features], reshape to [batch, features]
            let narrowed = input.narrow(1, t, 1);
            let step_input = narrowed.reshape(&[batch_size, narrowed.data().numel() / batch_size]);

            let mut layer_input = step_input;

            for (layer_idx, cell) in self.cells.iter().enumerate() {
                let new_hidden = cell.forward_step(&layer_input, &hidden_states[layer_idx]);
                hidden_states[layer_idx] = new_hidden.clone();
                layer_input = new_hidden;
            }
        }

        // Return last hidden state from last layer
        hidden_states
            .pop()
            .unwrap_or_else(|| Variable::new(zeros(&[batch_size, self.hidden_size]), false))
    }

    /// Stack output Variables into a single [batch, seq, hidden] tensor.
    /// Note: This creates a new tensor without gradient connections to individual timesteps.
    /// For gradient flow, use forward_mean() or forward_last() instead.
    fn stack_outputs(&self, outputs: &[Variable], batch_size: usize, seq_len: usize) -> Variable {
        if outputs.is_empty() {
            return Variable::new(
                zeros(&[batch_size, 0, self.hidden_size]),
                false,
            );
        }

        let output_shape = [batch_size, seq_len, self.hidden_size];
        let requires_grad = outputs.iter().any(|o| o.requires_grad());

        let mut stacked_data = vec![0.0f32; batch_size * seq_len * self.hidden_size];
        for (t, out) in outputs.iter().enumerate() {
            let out_data = out.data().to_vec();
            for b in 0..batch_size {
                for h in 0..self.hidden_size {
                    let idx = b * seq_len * self.hidden_size + t * self.hidden_size + h;
                    stacked_data[idx] = out_data[b * self.hidden_size + h];
                }
            }
        }

        Variable::new(
            Tensor::from_vec(stacked_data, &output_shape).unwrap(),
            requires_grad,
        )
    }
}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_rnn_cell() {
        let cell = RNNCell::new(10, 20);
        let input = Variable::new(Tensor::from_vec(vec![1.0; 20], &[2, 10]).unwrap(), false);
        let hidden = Variable::new(Tensor::from_vec(vec![0.0; 40], &[2, 20]).unwrap(), false);
        let output = cell.forward_step(&input, &hidden);
        assert_eq!(output.shape(), vec![2, 20]);
    }

    #[test]
    fn test_rnn() {
        let rnn = RNN::new(10, 20, 2);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 100], &[2, 5, 10]).unwrap(),
            false,
        );
        let output = rnn.forward(&input);
        assert_eq!(output.shape(), vec![2, 5, 20]);
    }

    #[test]
    fn test_lstm() {
        let lstm = LSTM::new(10, 20, 1);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 100], &[2, 5, 10]).unwrap(),
            false,
        );
        let output = lstm.forward(&input);
        assert_eq!(output.shape(), vec![2, 5, 20]);
    }
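
    // Additional shape checks, sketched in the same style as the tests above.

    #[test]
    fn test_lstm_cell() {
        let cell = LSTMCell::new(10, 20);
        let input = Variable::new(Tensor::from_vec(vec![1.0; 20], &[2, 10]).unwrap(), false);
        let h = Variable::new(Tensor::from_vec(vec![0.0; 40], &[2, 20]).unwrap(), false);
        let c = Variable::new(Tensor::from_vec(vec![0.0; 40], &[2, 20]).unwrap(), false);
        let (h_new, c_new) = cell.forward_step(&input, &(h, c));
        assert_eq!(h_new.shape(), vec![2, 20]);
        assert_eq!(c_new.shape(), vec![2, 20]);
    }

    #[test]
    fn test_gru_cell() {
        let cell = GRUCell::new(10, 20);
        let input = Variable::new(Tensor::from_vec(vec![1.0; 20], &[2, 10]).unwrap(), false);
        let hidden = Variable::new(Tensor::from_vec(vec![0.0; 40], &[2, 20]).unwrap(), false);
        let output = cell.forward_step(&input, &hidden);
        assert_eq!(output.shape(), vec![2, 20]);
    }

    #[test]
    fn test_gru() {
        let gru = GRU::new(10, 20, 2);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 100], &[2, 5, 10]).unwrap(),
            false,
        );
        let output = gru.forward(&input);
        assert_eq!(output.shape(), vec![2, 5, 20]);

        let last = gru.forward_last(&input);
        assert_eq!(last.shape(), vec![2, 20]);
    }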
}