Skip to main content

axonml_nn/layers/
rnn.rs

//! Recurrent Neural Network Layers - RNN, LSTM, GRU
//!
//! Processes sequential data with recurrent connections.
//!
//! @version 0.1.0
//! @author AutomataNexus Development Team
8use std::collections::HashMap;
9
10use axonml_autograd::Variable;
11use axonml_tensor::Tensor;
12
13use crate::init::{xavier_uniform, zeros};
14use crate::module::Module;
15use crate::parameter::Parameter;
16
17// =============================================================================
18// RNNCell
19// =============================================================================
20
21/// A single RNN cell.
22///
23/// h' = tanh(W_ih * x + b_ih + W_hh * h + b_hh)
24pub struct RNNCell {
25    /// Input-hidden weights.
26    pub weight_ih: Parameter,
27    /// Hidden-hidden weights.
28    pub weight_hh: Parameter,
29    /// Input-hidden bias.
30    pub bias_ih: Parameter,
31    /// Hidden-hidden bias.
32    pub bias_hh: Parameter,
33    /// Input size.
34    input_size: usize,
35    /// Hidden size.
36    hidden_size: usize,
37}
38
39impl RNNCell {
40    /// Creates a new RNNCell.
41    pub fn new(input_size: usize, hidden_size: usize) -> Self {
42        Self {
43            weight_ih: Parameter::named("weight_ih", xavier_uniform(input_size, hidden_size), true),
44            weight_hh: Parameter::named(
45                "weight_hh",
46                xavier_uniform(hidden_size, hidden_size),
47                true,
48            ),
49            bias_ih: Parameter::named("bias_ih", zeros(&[hidden_size]), true),
50            bias_hh: Parameter::named("bias_hh", zeros(&[hidden_size]), true),
51            input_size,
52            hidden_size,
53        }
54    }
55
56    /// Returns the expected input size.
57    pub fn input_size(&self) -> usize {
58        self.input_size
59    }
60
61    /// Returns the hidden state size.
62    pub fn hidden_size(&self) -> usize {
63        self.hidden_size
64    }
65
66    /// Forward pass for a single time step.
67    pub fn forward_step(&self, input: &Variable, hidden: &Variable) -> Variable {
68        let input_features = input.data().shape().last().copied().unwrap_or(0);
69        assert_eq!(
70            input_features, self.input_size,
71            "RNNCell: expected input size {}, got {}",
72            self.input_size, input_features
73        );
74        // x @ W_ih^T + b_ih
75        let weight_ih = self.weight_ih.variable();
76        let weight_ih_t = weight_ih.transpose(0, 1);
77        let ih = input.matmul(&weight_ih_t);
78        let bias_ih = self.bias_ih.variable();
79        let ih = ih.add_var(&bias_ih);
80
81        // h @ W_hh^T + b_hh
82        let weight_hh = self.weight_hh.variable();
83        let weight_hh_t = weight_hh.transpose(0, 1);
84        let hh = hidden.matmul(&weight_hh_t);
85        let bias_hh = self.bias_hh.variable();
86        let hh = hh.add_var(&bias_hh);
87
88        // tanh(ih + hh)
89        ih.add_var(&hh).tanh()
90    }
91}
92
93impl Module for RNNCell {
94    fn forward(&self, input: &Variable) -> Variable {
95        // Initialize hidden state to zeros
96        let batch_size = input.shape()[0];
97        let hidden = Variable::new(
98            zeros(&[batch_size, self.hidden_size]),
99            input.requires_grad(),
100        );
101        self.forward_step(input, &hidden)
102    }
103
104    fn parameters(&self) -> Vec<Parameter> {
105        vec![
106            self.weight_ih.clone(),
107            self.weight_hh.clone(),
108            self.bias_ih.clone(),
109            self.bias_hh.clone(),
110        ]
111    }
112
113    fn named_parameters(&self) -> HashMap<String, Parameter> {
114        let mut params = HashMap::new();
115        params.insert("weight_ih".to_string(), self.weight_ih.clone());
116        params.insert("weight_hh".to_string(), self.weight_hh.clone());
117        params.insert("bias_ih".to_string(), self.bias_ih.clone());
118        params.insert("bias_hh".to_string(), self.bias_hh.clone());
119        params
120    }
121
122    fn name(&self) -> &'static str {
123        "RNNCell"
124    }
125}
126
127// =============================================================================
128// RNN
129// =============================================================================
130
131/// Multi-layer RNN.
132///
133/// Processes sequences through stacked RNN layers.
134pub struct RNN {
135    /// RNN cells for each layer.
136    cells: Vec<RNNCell>,
137    /// Input size.
138    input_size: usize,
139    /// Hidden size.
140    hidden_size: usize,
141    /// Number of layers.
142    num_layers: usize,
143    /// Batch first flag.
144    batch_first: bool,
145}
146
147impl RNN {
148    /// Creates a new multi-layer RNN.
149    pub fn new(input_size: usize, hidden_size: usize, num_layers: usize) -> Self {
150        Self::with_options(input_size, hidden_size, num_layers, true)
151    }
152
153    /// Creates an RNN with all options.
154    pub fn with_options(
155        input_size: usize,
156        hidden_size: usize,
157        num_layers: usize,
158        batch_first: bool,
159    ) -> Self {
160        let mut cells = Vec::with_capacity(num_layers);
161
162        // First layer takes input_size
163        cells.push(RNNCell::new(input_size, hidden_size));
164
165        // Subsequent layers take hidden_size
166        for _ in 1..num_layers {
167            cells.push(RNNCell::new(hidden_size, hidden_size));
168        }
169
170        Self {
171            cells,
172            input_size,
173            hidden_size,
174            num_layers,
175            batch_first,
176        }
177    }
178}
179
180impl Module for RNN {
181    fn forward(&self, input: &Variable) -> Variable {
182        let shape = input.shape();
183        let (batch_size, seq_len, _) = if self.batch_first {
184            (shape[0], shape[1], shape[2])
185        } else {
186            (shape[1], shape[0], shape[2])
187        };
188
189        // Initialize hidden states
190        let mut hiddens: Vec<Variable> = (0..self.num_layers)
191            .map(|_| {
192                Variable::new(
193                    zeros(&[batch_size, self.hidden_size]),
194                    input.requires_grad(),
195                )
196            })
197            .collect();
198
199        // Process each time step
200        let input_data = input.data();
201        let mut outputs = Vec::with_capacity(seq_len);
202
203        for t in 0..seq_len {
204            // Extract input at time t
205            let t_input = if self.batch_first {
206                // [batch, seq, features] -> extract [batch, features] at t
207                let mut slice_data = vec![0.0f32; batch_size * self.input_size];
208                let input_vec = input_data.to_vec();
209                for b in 0..batch_size {
210                    for f in 0..self.input_size {
211                        let src_idx = b * seq_len * self.input_size + t * self.input_size + f;
212                        let dst_idx = b * self.input_size + f;
213                        slice_data[dst_idx] = input_vec[src_idx];
214                    }
215                }
216                Variable::new(
217                    Tensor::from_vec(slice_data, &[batch_size, self.input_size]).unwrap(),
218                    input.requires_grad(),
219                )
220            } else {
221                // [seq, batch, features] -> extract [batch, features] at t
222                let mut slice_data = vec![0.0f32; batch_size * self.input_size];
223                let input_vec = input_data.to_vec();
224                for b in 0..batch_size {
225                    for f in 0..self.input_size {
226                        let src_idx = t * batch_size * self.input_size + b * self.input_size + f;
227                        let dst_idx = b * self.input_size + f;
228                        slice_data[dst_idx] = input_vec[src_idx];
229                    }
230                }
231                Variable::new(
232                    Tensor::from_vec(slice_data, &[batch_size, self.input_size]).unwrap(),
233                    input.requires_grad(),
234                )
235            };
236
237            // Process through layers
238            let mut layer_input = t_input;
239            for (l, cell) in self.cells.iter().enumerate() {
240                hiddens[l] = cell.forward_step(&layer_input, &hiddens[l]);
241                layer_input = hiddens[l].clone();
242            }
243
244            outputs.push(hiddens[self.num_layers - 1].clone());
245        }
246
247        // Stack outputs using graph-tracked cat (unsqueeze + cat along time dim)
248        let time_dim = if self.batch_first { 1 } else { 0 };
249        let unsqueezed: Vec<Variable> = outputs.iter()
250            .map(|o| o.unsqueeze(time_dim))
251            .collect();
252        let refs: Vec<&Variable> = unsqueezed.iter().collect();
253        Variable::cat(&refs, time_dim)
254    }
255
256    fn parameters(&self) -> Vec<Parameter> {
257        self.cells.iter().flat_map(|c| c.parameters()).collect()
258    }
259
260    fn name(&self) -> &'static str {
261        "RNN"
262    }
263}
264
265// =============================================================================
266// LSTMCell
267// =============================================================================
268
269/// A single LSTM cell.
270pub struct LSTMCell {
271    /// Input-hidden weights for all gates.
272    pub weight_ih: Parameter,
273    /// Hidden-hidden weights for all gates.
274    pub weight_hh: Parameter,
275    /// Input-hidden bias for all gates.
276    pub bias_ih: Parameter,
277    /// Hidden-hidden bias for all gates.
278    pub bias_hh: Parameter,
279    /// Input size.
280    input_size: usize,
281    /// Hidden size.
282    hidden_size: usize,
283}
284
285impl LSTMCell {
286    /// Creates a new LSTMCell.
287    pub fn new(input_size: usize, hidden_size: usize) -> Self {
288        // LSTM has 4 gates, so weight size is 4*hidden_size
289        Self {
290            weight_ih: Parameter::named(
291                "weight_ih",
292                xavier_uniform(input_size, 4 * hidden_size),
293                true,
294            ),
295            weight_hh: Parameter::named(
296                "weight_hh",
297                xavier_uniform(hidden_size, 4 * hidden_size),
298                true,
299            ),
300            bias_ih: Parameter::named("bias_ih", zeros(&[4 * hidden_size]), true),
301            bias_hh: Parameter::named("bias_hh", zeros(&[4 * hidden_size]), true),
302            input_size,
303            hidden_size,
304        }
305    }
306
307    /// Returns the expected input size.
308    pub fn input_size(&self) -> usize {
309        self.input_size
310    }
311
312    /// Returns the hidden state size.
313    pub fn hidden_size(&self) -> usize {
314        self.hidden_size
315    }
316
317    /// Forward pass returning (h', c').
318    pub fn forward_step(
319        &self,
320        input: &Variable,
321        hx: &(Variable, Variable),
322    ) -> (Variable, Variable) {
323        let input_features = input.data().shape().last().copied().unwrap_or(0);
324        assert_eq!(
325            input_features, self.input_size,
326            "LSTMCell: expected input size {}, got {}",
327            self.input_size, input_features
328        );
329
330        let (h, c) = hx;
331
332        // Compute all gates at once (x @ W^T + b)
333        let weight_ih = self.weight_ih.variable();
334        let weight_ih_t = weight_ih.transpose(0, 1);
335        let ih = input.matmul(&weight_ih_t);
336        let bias_ih = self.bias_ih.variable();
337        let ih = ih.add_var(&bias_ih);
338
339        let weight_hh = self.weight_hh.variable();
340        let weight_hh_t = weight_hh.transpose(0, 1);
341        let hh = h.matmul(&weight_hh_t);
342        let bias_hh = self.bias_hh.variable();
343        let hh = hh.add_var(&bias_hh);
344
345        let gates = ih.add_var(&hh);
346        let gates_vec = gates.data().to_vec();
347        let batch_size = input.shape()[0];
348
349        // Split into 4 gates: i, f, g, o
350        let mut i_data = vec![0.0f32; batch_size * self.hidden_size];
351        let mut f_data = vec![0.0f32; batch_size * self.hidden_size];
352        let mut g_data = vec![0.0f32; batch_size * self.hidden_size];
353        let mut o_data = vec![0.0f32; batch_size * self.hidden_size];
354
355        for b in 0..batch_size {
356            for j in 0..self.hidden_size {
357                let base = b * 4 * self.hidden_size;
358                i_data[b * self.hidden_size + j] = gates_vec[base + j];
359                f_data[b * self.hidden_size + j] = gates_vec[base + self.hidden_size + j];
360                g_data[b * self.hidden_size + j] = gates_vec[base + 2 * self.hidden_size + j];
361                o_data[b * self.hidden_size + j] = gates_vec[base + 3 * self.hidden_size + j];
362            }
363        }
364
365        let i = Variable::new(
366            Tensor::from_vec(i_data, &[batch_size, self.hidden_size]).unwrap(),
367            input.requires_grad(),
368        )
369        .sigmoid();
370        let f = Variable::new(
371            Tensor::from_vec(f_data, &[batch_size, self.hidden_size]).unwrap(),
372            input.requires_grad(),
373        )
374        .sigmoid();
375        let g = Variable::new(
376            Tensor::from_vec(g_data, &[batch_size, self.hidden_size]).unwrap(),
377            input.requires_grad(),
378        )
379        .tanh();
380        let o = Variable::new(
381            Tensor::from_vec(o_data, &[batch_size, self.hidden_size]).unwrap(),
382            input.requires_grad(),
383        )
384        .sigmoid();
385
386        // c' = f * c + i * g
387        let c_new = f.mul_var(c).add_var(&i.mul_var(&g));
388
389        // h' = o * tanh(c')
390        let h_new = o.mul_var(&c_new.tanh());
391
392        (h_new, c_new)
393    }
394}
395
396impl Module for LSTMCell {
397    fn forward(&self, input: &Variable) -> Variable {
398        let batch_size = input.shape()[0];
399        let h = Variable::new(
400            zeros(&[batch_size, self.hidden_size]),
401            input.requires_grad(),
402        );
403        let c = Variable::new(
404            zeros(&[batch_size, self.hidden_size]),
405            input.requires_grad(),
406        );
407        let (h_new, _) = self.forward_step(input, &(h, c));
408        h_new
409    }
410
411    fn parameters(&self) -> Vec<Parameter> {
412        vec![
413            self.weight_ih.clone(),
414            self.weight_hh.clone(),
415            self.bias_ih.clone(),
416            self.bias_hh.clone(),
417        ]
418    }
419
420    fn named_parameters(&self) -> HashMap<String, Parameter> {
421        let mut params = HashMap::new();
422        params.insert("weight_ih".to_string(), self.weight_ih.clone());
423        params.insert("weight_hh".to_string(), self.weight_hh.clone());
424        params.insert("bias_ih".to_string(), self.bias_ih.clone());
425        params.insert("bias_hh".to_string(), self.bias_hh.clone());
426        params
427    }
428
429    fn name(&self) -> &'static str {
430        "LSTMCell"
431    }
432}
433
434// =============================================================================
435// LSTM
436// =============================================================================
437
438/// Multi-layer LSTM.
439pub struct LSTM {
440    /// LSTM cells for each layer.
441    cells: Vec<LSTMCell>,
442    /// Input size.
443    input_size: usize,
444    /// Hidden size.
445    hidden_size: usize,
446    /// Number of layers.
447    num_layers: usize,
448    /// Batch first flag.
449    batch_first: bool,
450}
451
452impl LSTM {
453    /// Creates a new multi-layer LSTM.
454    pub fn new(input_size: usize, hidden_size: usize, num_layers: usize) -> Self {
455        Self::with_options(input_size, hidden_size, num_layers, true)
456    }
457
458    /// Creates an LSTM with all options.
459    pub fn with_options(
460        input_size: usize,
461        hidden_size: usize,
462        num_layers: usize,
463        batch_first: bool,
464    ) -> Self {
465        let mut cells = Vec::with_capacity(num_layers);
466        cells.push(LSTMCell::new(input_size, hidden_size));
467        for _ in 1..num_layers {
468            cells.push(LSTMCell::new(hidden_size, hidden_size));
469        }
470
471        Self {
472            cells,
473            input_size,
474            hidden_size,
475            num_layers,
476            batch_first,
477        }
478    }
479
480    /// Returns the expected input size.
481    pub fn input_size(&self) -> usize {
482        self.input_size
483    }
484
485    /// Returns the hidden state size.
486    pub fn hidden_size(&self) -> usize {
487        self.hidden_size
488    }
489
490    /// Returns the number of layers.
491    pub fn num_layers(&self) -> usize {
492        self.num_layers
493    }
494}
495
496impl Module for LSTM {
497    fn forward(&self, input: &Variable) -> Variable {
498        // Similar to RNN forward but using LSTM cells
499        // For brevity, implementing a simplified version
500        let shape = input.shape();
501        let (batch_size, seq_len, input_features) = if self.batch_first {
502            (shape[0], shape[1], shape[2])
503        } else {
504            (shape[1], shape[0], shape[2])
505        };
506
507        let mut states: Vec<(Variable, Variable)> = (0..self.num_layers)
508            .map(|_| {
509                (
510                    Variable::new(
511                        zeros(&[batch_size, self.hidden_size]),
512                        input.requires_grad(),
513                    ),
514                    Variable::new(
515                        zeros(&[batch_size, self.hidden_size]),
516                        input.requires_grad(),
517                    ),
518                )
519            })
520            .collect();
521
522        let input_data = input.data();
523        let input_vec = input_data.to_vec();
524        let mut outputs = Vec::with_capacity(seq_len);
525
526        for t in 0..seq_len {
527            let mut slice_data = vec![0.0f32; batch_size * input_features];
528            for b in 0..batch_size {
529                for f in 0..input_features {
530                    let src_idx = if self.batch_first {
531                        b * seq_len * input_features + t * input_features + f
532                    } else {
533                        t * batch_size * input_features + b * input_features + f
534                    };
535                    slice_data[b * input_features + f] = input_vec[src_idx];
536                }
537            }
538
539            // Input slice always has input_features dimensions
540            let mut layer_input = Variable::new(
541                Tensor::from_vec(slice_data.clone(), &[batch_size, input_features]).unwrap(),
542                input.requires_grad(),
543            );
544
545            for (l, cell) in self.cells.iter().enumerate() {
546                // Resize input if needed for subsequent layers
547                if l > 0 {
548                    layer_input = states[l - 1].0.clone();
549                }
550                states[l] = cell.forward_step(&layer_input, &states[l]);
551            }
552
553            outputs.push(states[self.num_layers - 1].0.clone());
554        }
555
556        // Stack outputs using graph-tracked cat (unsqueeze + cat along time dim)
557        let time_dim = if self.batch_first { 1 } else { 0 };
558        let unsqueezed: Vec<Variable> = outputs.iter()
559            .map(|o| o.unsqueeze(time_dim))
560            .collect();
561        let refs: Vec<&Variable> = unsqueezed.iter().collect();
562        Variable::cat(&refs, time_dim)
563    }
564
565    fn parameters(&self) -> Vec<Parameter> {
566        self.cells.iter().flat_map(|c| c.parameters()).collect()
567    }
568
569    fn named_parameters(&self) -> HashMap<String, Parameter> {
570        let mut params = HashMap::new();
571        if self.cells.len() == 1 {
572            // Single layer: expose directly without cell index prefix
573            for (n, p) in self.cells[0].named_parameters() {
574                params.insert(n, p);
575            }
576        } else {
577            for (i, cell) in self.cells.iter().enumerate() {
578                for (n, p) in cell.named_parameters() {
579                    params.insert(format!("cells.{i}.{n}"), p);
580                }
581            }
582        }
583        params
584    }
585
586    fn name(&self) -> &'static str {
587        "LSTM"
588    }
589}
590
591// =============================================================================
592// GRUCell and GRU
593// =============================================================================
594
595/// A single GRU cell.
596///
597/// h' = (1 - z) * n + z * h
598/// where:
599///   r = sigmoid(W_ir * x + b_ir + W_hr * h + b_hr)  (reset gate)
600///   z = sigmoid(W_iz * x + b_iz + W_hz * h + b_hz)  (update gate)
601///   n = tanh(W_in * x + b_in + r * (W_hn * h + b_hn))  (new gate)
602pub struct GRUCell {
603    /// Input-hidden weights for all gates (reset, update, new).
604    pub weight_ih: Parameter,
605    /// Hidden-hidden weights for all gates (reset, update, new).
606    pub weight_hh: Parameter,
607    /// Input-hidden bias for all gates.
608    pub bias_ih: Parameter,
609    /// Hidden-hidden bias for all gates.
610    pub bias_hh: Parameter,
611    /// Input size.
612    input_size: usize,
613    /// Hidden size.
614    hidden_size: usize,
615}
616
617impl GRUCell {
618    /// Creates a new GRU cell.
619    pub fn new(input_size: usize, hidden_size: usize) -> Self {
620        Self {
621            weight_ih: Parameter::named(
622                "weight_ih",
623                xavier_uniform(input_size, 3 * hidden_size),
624                true,
625            ),
626            weight_hh: Parameter::named(
627                "weight_hh",
628                xavier_uniform(hidden_size, 3 * hidden_size),
629                true,
630            ),
631            bias_ih: Parameter::named("bias_ih", zeros(&[3 * hidden_size]), true),
632            bias_hh: Parameter::named("bias_hh", zeros(&[3 * hidden_size]), true),
633            input_size,
634            hidden_size,
635        }
636    }
637
638    /// Returns the expected input size.
639    pub fn input_size(&self) -> usize {
640        self.input_size
641    }
642
643    /// Returns the hidden state size.
644    pub fn hidden_size(&self) -> usize {
645        self.hidden_size
646    }
647}
648
649impl GRUCell {
650    /// Forward pass for a single time step with explicit hidden state.
651    ///
652    /// GRU equations:
653    /// r_t = sigmoid(W_ir @ x_t + b_ir + W_hr @ h_{t-1} + b_hr)
654    /// z_t = sigmoid(W_iz @ x_t + b_iz + W_hz @ h_{t-1} + b_hz)
655    /// n_t = tanh(W_in @ x_t + b_in + r_t * (W_hn @ h_{t-1} + b_hn))
656    /// h_t = (1 - z_t) * n_t + z_t * h_{t-1}
657    ///
658    /// All computations use Variable operations for proper gradient flow.
659    pub fn forward_step(&self, input: &Variable, hidden: &Variable) -> Variable {
660        let batch_size = input.shape()[0];
661        let hidden_size = self.hidden_size;
662
663        // Get weight matrices
664        let weight_ih = self.weight_ih.variable();
665        let weight_hh = self.weight_hh.variable();
666        let bias_ih = self.bias_ih.variable();
667        let bias_hh = self.bias_hh.variable();
668
669        // Compute input transformation: x @ W_ih^T + b_ih
670        // Shape: [batch, 3*hidden_size]
671        let weight_ih_t = weight_ih.transpose(0, 1);
672        let ih = input.matmul(&weight_ih_t).add_var(&bias_ih);
673
674        // Compute hidden transformation: h @ W_hh^T + b_hh
675        // Shape: [batch, 3*hidden_size]
676        let weight_hh_t = weight_hh.transpose(0, 1);
677        let hh = hidden.matmul(&weight_hh_t).add_var(&bias_hh);
678
679        // Use narrow to split into gates (preserves gradient flow)
680        // Each gate slice: [batch, hidden_size]
681        let ih_r = ih.narrow(1, 0, hidden_size);
682        let ih_z = ih.narrow(1, hidden_size, hidden_size);
683        let ih_n = ih.narrow(1, 2 * hidden_size, hidden_size);
684
685        let hh_r = hh.narrow(1, 0, hidden_size);
686        let hh_z = hh.narrow(1, hidden_size, hidden_size);
687        let hh_n = hh.narrow(1, 2 * hidden_size, hidden_size);
688
689        // Compute gates using Variable operations for gradient flow
690        // r = sigmoid(ih_r + hh_r)
691        let r = ih_r.add_var(&hh_r).sigmoid();
692
693        // z = sigmoid(ih_z + hh_z)
694        let z = ih_z.add_var(&hh_z).sigmoid();
695
696        // n = tanh(ih_n + r * hh_n)
697        let n = ih_n.add_var(&r.mul_var(&hh_n)).tanh();
698
699        // h_new = (1 - z) * n + z * h_prev
700        // Create ones for (1 - z)
701        let shape = [batch_size, hidden_size];
702        let ones = Variable::new(
703            Tensor::from_vec(vec![1.0f32; batch_size * hidden_size], &shape).unwrap(),
704            false,
705        );
706        let one_minus_z = ones.sub_var(&z);
707
708        // h_new = one_minus_z * n + z * h_prev
709        one_minus_z.mul_var(&n).add_var(&z.mul_var(hidden))
710    }
711}
712
713impl Module for GRUCell {
714    fn forward(&self, input: &Variable) -> Variable {
715        let batch_size = input.shape()[0];
716
717        // Initialize hidden state to zeros
718        let hidden = Variable::new(
719            zeros(&[batch_size, self.hidden_size]),
720            input.requires_grad(),
721        );
722
723        self.forward_step(input, &hidden)
724    }
725
726    fn parameters(&self) -> Vec<Parameter> {
727        vec![
728            self.weight_ih.clone(),
729            self.weight_hh.clone(),
730            self.bias_ih.clone(),
731            self.bias_hh.clone(),
732        ]
733    }
734
735    fn named_parameters(&self) -> HashMap<String, Parameter> {
736        let mut params = HashMap::new();
737        params.insert("weight_ih".to_string(), self.weight_ih.clone());
738        params.insert("weight_hh".to_string(), self.weight_hh.clone());
739        params.insert("bias_ih".to_string(), self.bias_ih.clone());
740        params.insert("bias_hh".to_string(), self.bias_hh.clone());
741        params
742    }
743
744    fn name(&self) -> &'static str {
745        "GRUCell"
746    }
747}
748
749/// Multi-layer GRU.
750pub struct GRU {
751    /// GRU cells for each layer.
752    cells: Vec<GRUCell>,
753    /// Hidden state size.
754    hidden_size: usize,
755    /// Number of layers.
756    num_layers: usize,
757    /// If true, input is (batch, seq, features), else (seq, batch, features).
758    batch_first: bool,
759}
760
761impl GRU {
762    /// Creates a new multi-layer GRU.
763    pub fn new(input_size: usize, hidden_size: usize, num_layers: usize) -> Self {
764        let mut cells = Vec::with_capacity(num_layers);
765        cells.push(GRUCell::new(input_size, hidden_size));
766        for _ in 1..num_layers {
767            cells.push(GRUCell::new(hidden_size, hidden_size));
768        }
769        Self {
770            cells,
771            hidden_size,
772            num_layers,
773            batch_first: true,
774        }
775    }
776
777    /// Returns the hidden state size.
778    pub fn hidden_size(&self) -> usize {
779        self.hidden_size
780    }
781
782    /// Returns the number of layers.
783    pub fn num_layers(&self) -> usize {
784        self.num_layers
785    }
786}
787
788impl Module for GRU {
789    fn forward(&self, input: &Variable) -> Variable {
790        let shape = input.shape();
791        let (batch_size, seq_len, _input_size) = if self.batch_first {
792            (shape[0], shape[1], shape[2])
793        } else {
794            (shape[1], shape[0], shape[2])
795        };
796
797        // Initialize hidden states for all layers as Variables (with gradients)
798        let mut hidden_states: Vec<Variable> = (0..self.num_layers)
799            .map(|_| {
800                Variable::new(
801                    zeros(&[batch_size, self.hidden_size]),
802                    input.requires_grad(),
803                )
804            })
805            .collect();
806
807        // Collect output Variables for each time step
808        let mut output_vars: Vec<Variable> = Vec::with_capacity(seq_len);
809
810        // Process each time step
811        for t in 0..seq_len {
812            // Extract input for this time step using narrow (preserves gradients)
813            // input shape: [batch, seq, features]
814            // narrow to [batch, 1, features], then reshape to [batch, features]
815            // narrow gives [batch, 1, features], reshape to [batch, features]
816            let narrowed = input.narrow(1, t, 1);
817            let step_input = narrowed.reshape(&[batch_size, narrowed.data().numel() / batch_size]);
818
819            // Process through each layer
820            let mut layer_input = step_input;
821
822            for (layer_idx, cell) in self.cells.iter().enumerate() {
823                let new_hidden = cell.forward_step(&layer_input, &hidden_states[layer_idx]);
824
825                // Update hidden state for this layer (keeps gradient chain)
826                hidden_states[layer_idx] = new_hidden.clone();
827
828                // Output of this layer becomes input to next layer
829                layer_input = new_hidden;
830            }
831
832            // Store output from last layer for this time step
833            output_vars.push(layer_input);
834        }
835
836        // Stack outputs along the time dimension
837        // Each output_var has shape [batch, hidden_size]
838        // We need to combine them into [batch, seq, hidden_size]
839        self.stack_outputs(&output_vars, batch_size, seq_len)
840    }
841
842    fn parameters(&self) -> Vec<Parameter> {
843        self.cells.iter().flat_map(|c| c.parameters()).collect()
844    }
845
846    fn named_parameters(&self) -> HashMap<String, Parameter> {
847        let mut params = HashMap::new();
848        if self.cells.len() == 1 {
849            for (n, p) in self.cells[0].named_parameters() {
850                params.insert(n, p);
851            }
852        } else {
853            for (i, cell) in self.cells.iter().enumerate() {
854                for (n, p) in cell.named_parameters() {
855                    params.insert(format!("cells.{i}.{n}"), p);
856                }
857            }
858        }
859        params
860    }
861
862    fn name(&self) -> &'static str {
863        "GRU"
864    }
865}
866
867impl GRU {
868    /// Forward pass that returns the mean of all hidden states.
869    /// This is equivalent to processing then mean pooling, but with proper gradient flow.
870    pub fn forward_mean(&self, input: &Variable) -> Variable {
871        let shape = input.shape();
872        let (batch_size, seq_len, _input_size) = if self.batch_first {
873            (shape[0], shape[1], shape[2])
874        } else {
875            (shape[1], shape[0], shape[2])
876        };
877
878        // Initialize hidden states for all layers as Variables (with gradients)
879        let mut hidden_states: Vec<Variable> = (0..self.num_layers)
880            .map(|_| {
881                Variable::new(
882                    zeros(&[batch_size, self.hidden_size]),
883                    input.requires_grad(),
884                )
885            })
886            .collect();
887
888        // Accumulator for mean of outputs
889        let mut output_sum: Option<Variable> = None;
890
891        // Process each time step
892        for t in 0..seq_len {
893            // Extract input for this time step using narrow (preserves gradients)
894            // narrow gives [batch, 1, features], reshape to [batch, features]
895            let narrowed = input.narrow(1, t, 1);
896            let step_input = narrowed.reshape(&[batch_size, narrowed.data().numel() / batch_size]);
897
898            // Process through each layer
899            let mut layer_input = step_input;
900
901            for (layer_idx, cell) in self.cells.iter().enumerate() {
902                let new_hidden = cell.forward_step(&layer_input, &hidden_states[layer_idx]);
903                hidden_states[layer_idx] = new_hidden.clone();
904                layer_input = new_hidden;
905            }
906
907            // Accumulate output (last layer's hidden state)
908            output_sum = Some(match output_sum {
909                None => layer_input,
910                Some(acc) => acc.add_var(&layer_input),
911            });
912        }
913
914        // Return mean of all hidden states
915        match output_sum {
916            Some(sum) => sum.mul_scalar(1.0 / seq_len as f32),
917            None => Variable::new(zeros(&[batch_size, self.hidden_size]), false),
918        }
919    }
920
921    /// Forward pass that returns the last hidden state.
922    /// Good for sequence classification with proper gradient flow.
923    pub fn forward_last(&self, input: &Variable) -> Variable {
924        let shape = input.shape();
925        let (batch_size, seq_len, _input_size) = if self.batch_first {
926            (shape[0], shape[1], shape[2])
927        } else {
928            (shape[1], shape[0], shape[2])
929        };
930
931        // Initialize hidden states for all layers
932        let mut hidden_states: Vec<Variable> = (0..self.num_layers)
933            .map(|_| {
934                Variable::new(
935                    zeros(&[batch_size, self.hidden_size]),
936                    input.requires_grad(),
937                )
938            })
939            .collect();
940
941        // Process each time step
942        for t in 0..seq_len {
943            // narrow gives [batch, 1, features], reshape to [batch, features]
944            let narrowed = input.narrow(1, t, 1);
945            let step_input = narrowed.reshape(&[batch_size, narrowed.data().numel() / batch_size]);
946
947            let mut layer_input = step_input;
948
949            for (layer_idx, cell) in self.cells.iter().enumerate() {
950                let new_hidden = cell.forward_step(&layer_input, &hidden_states[layer_idx]);
951                hidden_states[layer_idx] = new_hidden.clone();
952                layer_input = new_hidden;
953            }
954        }
955
956        // Return last hidden state from last layer
957        hidden_states
958            .pop()
959            .unwrap_or_else(|| Variable::new(zeros(&[batch_size, self.hidden_size]), false))
960    }
961
962    /// Stack output Variables into a single [batch, seq, hidden] tensor.
963    /// Note: This creates a new tensor without gradient connections to individual timesteps.
964    /// For gradient flow, use forward_mean() or forward_last() instead.
965    fn stack_outputs(&self, outputs: &[Variable], batch_size: usize, _seq_len: usize) -> Variable {
966        if outputs.is_empty() {
967            return Variable::new(zeros(&[batch_size, 0, self.hidden_size]), false);
968        }
969
970        // Unsqueeze each (batch, hidden) → (batch, 1, hidden), then cat along dim=1
971        let unsqueezed: Vec<Variable> = outputs.iter()
972            .map(|o| o.unsqueeze(1))
973            .collect();
974        let refs: Vec<&Variable> = unsqueezed.iter().collect();
975        Variable::cat(&refs, 1)
976    }
977}
978
979// =============================================================================
980// Tests
981// =============================================================================
982
#[cfg(test)]
mod tests {
    use super::*;

    /// A single RNN cell step maps [batch, in] + [batch, hidden] -> [batch, hidden].
    #[test]
    fn test_rnn_cell() {
        let cell = RNNCell::new(10, 20);
        let input = Variable::new(Tensor::from_vec(vec![1.0; 20], &[2, 10]).unwrap(), false);
        let hidden = Variable::new(Tensor::from_vec(vec![0.0; 40], &[2, 20]).unwrap(), false);
        let output = cell.forward_step(&input, &hidden);
        assert_eq!(output.shape(), vec![2, 20]);
    }

    /// A stacked RNN maps [batch, seq, in] -> [batch, seq, hidden].
    #[test]
    fn test_rnn() {
        let rnn = RNN::new(10, 20, 2);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 100], &[2, 5, 10]).unwrap(),
            false,
        );
        let output = rnn.forward(&input);
        assert_eq!(output.shape(), vec![2, 5, 20]);
    }

    /// An LSTM maps [batch, seq, in] -> [batch, seq, hidden].
    #[test]
    fn test_lstm() {
        let lstm = LSTM::new(10, 20, 1);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 100], &[2, 5, 10]).unwrap(),
            false,
        );
        let output = lstm.forward(&input);
        assert_eq!(output.shape(), vec![2, 5, 20]);
    }

    /// Backprop through a small GRU must deposit a non-zero gradient in at
    /// least one parameter; otherwise the recurrent graph is disconnected.
    #[test]
    fn test_gru_gradients_reach_parameters() {
        let gru = GRU::new(4, 8, 1);
        let input = Variable::new(
            Tensor::from_vec(vec![0.5f32; 2 * 3 * 4], &[2, 3, 4]).unwrap(),
            true,
        );
        let loss = gru.forward(&input).sum();
        loss.backward();

        // Any parameter with a non-zero |grad| sum proves gradient flow.
        let has_grad = gru.parameters().iter().any(|p| {
            p.grad()
                .map(|g| g.to_vec().iter().map(|x| x.abs()).sum::<f32>() > 0.0)
                .unwrap_or(false)
        });
        assert!(
            has_grad,
            "At least one GRU parameter should have non-zero gradients"
        );
    }
}