Skip to main content

axonml_nn/layers/
rnn.rs

1//! Recurrent layers — `RNN`, `LSTM`, `GRU` and their cell variants.
2//!
3//! 1490 lines. `RNNCell` / `RNN` (tanh or relu nonlinearity, multi-layer,
4//! bidirectional), `LSTMCell` / `LSTM` (forget/input/output/cell gates,
5//! multi-layer, returns (output, (h_n, c_n))), `GRUCell` / `GRU`
6//! (update/reset gates, multi-layer). All support `forward_step` (single
7//! timestep) and sequence-level `forward` (all timesteps). Cell variants
8//! expose gate-level state for custom sequence handling.
9//!
10//! # File
11//! `crates/axonml-nn/src/layers/rnn.rs`
12//!
13//! # Author
14//! Andrew Jewell Sr. — AutomataNexus LLC
15//! ORCID: 0009-0005-2158-7060
16//!
17//! # Updated
18//! April 14, 2026 11:15 PM EST
19//!
20//! # Disclaimer
21//! Use at own risk. This software is provided "as is", without warranty of any
22//! kind, express or implied. The author and AutomataNexus shall not be held
23//! liable for any damages arising from the use of this software.
24
25use std::collections::HashMap;
26
27use axonml_autograd::Variable;
28
29use crate::init::{xavier_uniform, zeros};
30use crate::module::Module;
31use crate::parameter::Parameter;
32
33// =============================================================================
34// RNNCell
35// =============================================================================
36
37/// A single RNN cell.
38///
39/// h' = tanh(W_ih * x + b_ih + W_hh * h + b_hh)
40pub struct RNNCell {
41    /// Input-hidden weights.
42    pub weight_ih: Parameter,
43    /// Hidden-hidden weights.
44    pub weight_hh: Parameter,
45    /// Input-hidden bias.
46    pub bias_ih: Parameter,
47    /// Hidden-hidden bias.
48    pub bias_hh: Parameter,
49    /// Input size.
50    input_size: usize,
51    /// Hidden size.
52    hidden_size: usize,
53}
54
55impl RNNCell {
56    /// Creates a new RNNCell.
57    pub fn new(input_size: usize, hidden_size: usize) -> Self {
58        Self {
59            weight_ih: Parameter::named("weight_ih", xavier_uniform(input_size, hidden_size), true),
60            weight_hh: Parameter::named(
61                "weight_hh",
62                xavier_uniform(hidden_size, hidden_size),
63                true,
64            ),
65            bias_ih: Parameter::named("bias_ih", zeros(&[hidden_size]), true),
66            bias_hh: Parameter::named("bias_hh", zeros(&[hidden_size]), true),
67            input_size,
68            hidden_size,
69        }
70    }
71
72    /// Returns the expected input size.
73    pub fn input_size(&self) -> usize {
74        self.input_size
75    }
76
77    /// Returns the hidden state size.
78    pub fn hidden_size(&self) -> usize {
79        self.hidden_size
80    }
81
82    /// Forward pass for a single time step.
83    pub fn forward_step(&self, input: &Variable, hidden: &Variable) -> Variable {
84        let input_features = input.data().shape().last().copied().unwrap_or(0);
85        assert_eq!(
86            input_features, self.input_size,
87            "RNNCell: expected input size {}, got {}",
88            self.input_size, input_features
89        );
90        // x @ W_ih^T + b_ih
91        let weight_ih = self.weight_ih.variable();
92        let weight_ih_t = weight_ih.transpose(0, 1);
93        let ih = input.matmul(&weight_ih_t);
94        let bias_ih = self.bias_ih.variable();
95        let ih = ih.add_var(&bias_ih);
96
97        // h @ W_hh^T + b_hh
98        let weight_hh = self.weight_hh.variable();
99        let weight_hh_t = weight_hh.transpose(0, 1);
100        let hh = hidden.matmul(&weight_hh_t);
101        let bias_hh = self.bias_hh.variable();
102        let hh = hh.add_var(&bias_hh);
103
104        // tanh(ih + hh)
105        ih.add_var(&hh).tanh()
106    }
107}
108
109impl Module for RNNCell {
110    fn forward(&self, input: &Variable) -> Variable {
111        // Initialize hidden state to zeros
112        let batch_size = input.shape()[0];
113        let hidden = Variable::new(
114            zeros(&[batch_size, self.hidden_size]),
115            input.requires_grad(),
116        );
117        self.forward_step(input, &hidden)
118    }
119
120    fn parameters(&self) -> Vec<Parameter> {
121        vec![
122            self.weight_ih.clone(),
123            self.weight_hh.clone(),
124            self.bias_ih.clone(),
125            self.bias_hh.clone(),
126        ]
127    }
128
129    fn named_parameters(&self) -> HashMap<String, Parameter> {
130        let mut params = HashMap::new();
131        params.insert("weight_ih".to_string(), self.weight_ih.clone());
132        params.insert("weight_hh".to_string(), self.weight_hh.clone());
133        params.insert("bias_ih".to_string(), self.bias_ih.clone());
134        params.insert("bias_hh".to_string(), self.bias_hh.clone());
135        params
136    }
137
138    fn name(&self) -> &'static str {
139        "RNNCell"
140    }
141}
142
143// =============================================================================
144// RNN
145// =============================================================================
146
147/// Multi-layer RNN.
148///
149/// Processes sequences through stacked RNN layers.
150pub struct RNN {
151    /// RNN cells for each layer.
152    cells: Vec<RNNCell>,
153    /// Input size.
154    _input_size: usize,
155    /// Hidden size.
156    hidden_size: usize,
157    /// Number of layers.
158    num_layers: usize,
159    /// Batch first flag.
160    batch_first: bool,
161}
162
163impl RNN {
164    /// Creates a new multi-layer RNN.
165    pub fn new(input_size: usize, hidden_size: usize, num_layers: usize) -> Self {
166        Self::with_options(input_size, hidden_size, num_layers, true)
167    }
168
169    /// Creates an RNN with all options.
170    pub fn with_options(
171        input_size: usize,
172        hidden_size: usize,
173        num_layers: usize,
174        batch_first: bool,
175    ) -> Self {
176        let mut cells = Vec::with_capacity(num_layers);
177
178        // First layer takes input_size
179        cells.push(RNNCell::new(input_size, hidden_size));
180
181        // Subsequent layers take hidden_size
182        for _ in 1..num_layers {
183            cells.push(RNNCell::new(hidden_size, hidden_size));
184        }
185
186        Self {
187            cells,
188            _input_size: input_size,
189            hidden_size,
190            num_layers,
191            batch_first,
192        }
193    }
194}
195
196impl Module for RNN {
197    fn forward(&self, input: &Variable) -> Variable {
198        let shape = input.shape();
199        let (batch_size, seq_len, input_features) = if self.batch_first {
200            (shape[0], shape[1], shape[2])
201        } else {
202            (shape[1], shape[0], shape[2])
203        };
204
205        // Initialize hidden states
206        let mut hiddens: Vec<Variable> = (0..self.num_layers)
207            .map(|_| {
208                Variable::new(
209                    zeros(&[batch_size, self.hidden_size]),
210                    input.requires_grad(),
211                )
212            })
213            .collect();
214
215        // Pre-compute input-to-hidden projection for layer 0 across ALL timesteps
216        let cell0 = &self.cells[0];
217        let input_2d = input.reshape(&[batch_size * seq_len, input_features]);
218        let w_ih_t = cell0.weight_ih.variable().transpose(0, 1);
219        let ih_all = input_2d.matmul(&w_ih_t).add_var(&cell0.bias_ih.variable());
220        let ih_all_3d = ih_all.reshape(&[batch_size, seq_len, self.hidden_size]);
221
222        // Hoist weight transposes out of the per-timestep loop
223        let w_hh_t_0 = cell0.weight_hh.variable().transpose(0, 1);
224        let bias_hh_0 = cell0.bias_hh.variable();
225
226        let mut outputs = Vec::with_capacity(seq_len);
227
228        for t in 0..seq_len {
229            // Layer 0: use pre-computed ih projection + hoisted weight transpose
230            let ih_t = ih_all_3d.select(1, t);
231            let hh = hiddens[0].matmul(&w_hh_t_0).add_var(&bias_hh_0);
232            hiddens[0] = ih_t.add_var(&hh).tanh();
233
234            // Subsequent layers
235            for l in 1..self.num_layers {
236                let layer_input = hiddens[l - 1].clone();
237                hiddens[l] = self.cells[l].forward_step(&layer_input, &hiddens[l]);
238            }
239
240            outputs.push(hiddens[self.num_layers - 1].clone());
241        }
242
243        // Stack outputs using graph-tracked cat (unsqueeze + cat along time dim)
244        let time_dim = usize::from(self.batch_first);
245        let unsqueezed: Vec<Variable> = outputs.iter().map(|o| o.unsqueeze(time_dim)).collect();
246        let refs: Vec<&Variable> = unsqueezed.iter().collect();
247        Variable::cat(&refs, time_dim)
248    }
249
250    fn parameters(&self) -> Vec<Parameter> {
251        self.cells.iter().flat_map(|c| c.parameters()).collect()
252    }
253
254    fn name(&self) -> &'static str {
255        "RNN"
256    }
257}
258
259// =============================================================================
260// LSTMCell
261// =============================================================================
262
263/// A single LSTM cell.
264pub struct LSTMCell {
265    /// Input-hidden weights for all gates.
266    pub weight_ih: Parameter,
267    /// Hidden-hidden weights for all gates.
268    pub weight_hh: Parameter,
269    /// Input-hidden bias for all gates.
270    pub bias_ih: Parameter,
271    /// Hidden-hidden bias for all gates.
272    pub bias_hh: Parameter,
273    /// Input size.
274    input_size: usize,
275    /// Hidden size.
276    hidden_size: usize,
277}
278
279impl LSTMCell {
280    /// Creates a new LSTMCell.
281    pub fn new(input_size: usize, hidden_size: usize) -> Self {
282        // LSTM has 4 gates, so weight size is 4*hidden_size
283        Self {
284            weight_ih: Parameter::named(
285                "weight_ih",
286                xavier_uniform(input_size, 4 * hidden_size),
287                true,
288            ),
289            weight_hh: Parameter::named(
290                "weight_hh",
291                xavier_uniform(hidden_size, 4 * hidden_size),
292                true,
293            ),
294            bias_ih: Parameter::named("bias_ih", zeros(&[4 * hidden_size]), true),
295            bias_hh: Parameter::named("bias_hh", zeros(&[4 * hidden_size]), true),
296            input_size,
297            hidden_size,
298        }
299    }
300
301    /// Returns the expected input size.
302    pub fn input_size(&self) -> usize {
303        self.input_size
304    }
305
306    /// Returns the hidden state size.
307    pub fn hidden_size(&self) -> usize {
308        self.hidden_size
309    }
310
311    /// Forward pass returning (h', c').
312    pub fn forward_step(
313        &self,
314        input: &Variable,
315        hx: &(Variable, Variable),
316    ) -> (Variable, Variable) {
317        let input_features = input.data().shape().last().copied().unwrap_or(0);
318        assert_eq!(
319            input_features, self.input_size,
320            "LSTMCell: expected input size {}, got {}",
321            self.input_size, input_features
322        );
323
324        let (h, c) = hx;
325
326        // Compute all gates at once (x @ W^T + b)
327        let weight_ih = self.weight_ih.variable();
328        let weight_ih_t = weight_ih.transpose(0, 1);
329        let ih = input.matmul(&weight_ih_t);
330        let bias_ih = self.bias_ih.variable();
331        let ih = ih.add_var(&bias_ih);
332
333        let weight_hh = self.weight_hh.variable();
334        let weight_hh_t = weight_hh.transpose(0, 1);
335        let hh = h.matmul(&weight_hh_t);
336        let bias_hh = self.bias_hh.variable();
337        let hh = hh.add_var(&bias_hh);
338
339        let gates = ih.add_var(&hh);
340        let hs = self.hidden_size;
341
342        // Split into 4 gates using narrow (preserves gradient flow)
343        let i = gates.narrow(1, 0, hs).sigmoid();
344        let f = gates.narrow(1, hs, hs).sigmoid();
345        let g = gates.narrow(1, 2 * hs, hs).tanh();
346        let o = gates.narrow(1, 3 * hs, hs).sigmoid();
347
348        // c' = f * c + i * g
349        let c_new = f.mul_var(c).add_var(&i.mul_var(&g));
350
351        // h' = o * tanh(c')
352        let h_new = o.mul_var(&c_new.tanh());
353
354        (h_new, c_new)
355    }
356}
357
358impl Module for LSTMCell {
359    fn forward(&self, input: &Variable) -> Variable {
360        let batch_size = input.shape()[0];
361        let h = Variable::new(
362            zeros(&[batch_size, self.hidden_size]),
363            input.requires_grad(),
364        );
365        let c = Variable::new(
366            zeros(&[batch_size, self.hidden_size]),
367            input.requires_grad(),
368        );
369        let (h_new, _) = self.forward_step(input, &(h, c));
370        h_new
371    }
372
373    fn parameters(&self) -> Vec<Parameter> {
374        vec![
375            self.weight_ih.clone(),
376            self.weight_hh.clone(),
377            self.bias_ih.clone(),
378            self.bias_hh.clone(),
379        ]
380    }
381
382    fn named_parameters(&self) -> HashMap<String, Parameter> {
383        let mut params = HashMap::new();
384        params.insert("weight_ih".to_string(), self.weight_ih.clone());
385        params.insert("weight_hh".to_string(), self.weight_hh.clone());
386        params.insert("bias_ih".to_string(), self.bias_ih.clone());
387        params.insert("bias_hh".to_string(), self.bias_hh.clone());
388        params
389    }
390
391    fn name(&self) -> &'static str {
392        "LSTMCell"
393    }
394}
395
396// =============================================================================
397// LSTM
398// =============================================================================
399
400/// Multi-layer LSTM.
401pub struct LSTM {
402    /// LSTM cells for each layer.
403    cells: Vec<LSTMCell>,
404    /// Input size.
405    input_size: usize,
406    /// Hidden size.
407    hidden_size: usize,
408    /// Number of layers.
409    num_layers: usize,
410    /// Batch first flag.
411    batch_first: bool,
412}
413
414impl LSTM {
415    /// Creates a new multi-layer LSTM.
416    pub fn new(input_size: usize, hidden_size: usize, num_layers: usize) -> Self {
417        Self::with_options(input_size, hidden_size, num_layers, true)
418    }
419
420    /// Creates an LSTM with all options.
421    pub fn with_options(
422        input_size: usize,
423        hidden_size: usize,
424        num_layers: usize,
425        batch_first: bool,
426    ) -> Self {
427        let mut cells = Vec::with_capacity(num_layers);
428        cells.push(LSTMCell::new(input_size, hidden_size));
429        for _ in 1..num_layers {
430            cells.push(LSTMCell::new(hidden_size, hidden_size));
431        }
432
433        Self {
434            cells,
435            input_size,
436            hidden_size,
437            num_layers,
438            batch_first,
439        }
440    }
441
442    /// Returns the expected input size.
443    pub fn input_size(&self) -> usize {
444        self.input_size
445    }
446
447    /// Returns the hidden state size.
448    pub fn hidden_size(&self) -> usize {
449        self.hidden_size
450    }
451
452    /// Returns the number of layers.
453    pub fn num_layers(&self) -> usize {
454        self.num_layers
455    }
456}
457
458impl Module for LSTM {
459    fn forward(&self, input: &Variable) -> Variable {
460        let shape = input.shape();
461        let (batch_size, seq_len, input_features) = if self.batch_first {
462            (shape[0], shape[1], shape[2])
463        } else {
464            (shape[1], shape[0], shape[2])
465        };
466
467        let lstm_input_device = input.data().device();
468        #[cfg(feature = "cuda")]
469        let lstm_on_gpu = lstm_input_device.is_gpu();
470        #[cfg(not(feature = "cuda"))]
471        let lstm_on_gpu = false;
472
473        let mut states: Vec<(Variable, Variable)> = (0..self.num_layers)
474            .map(|_| {
475                let make_h = || {
476                    let h_cpu = zeros(&[batch_size, self.hidden_size]);
477                    let h_tensor = if lstm_on_gpu {
478                        h_cpu
479                            .to_device(lstm_input_device)
480                            .expect("LSTM: failed to move hidden state to GPU")
481                    } else {
482                        h_cpu
483                    };
484                    Variable::new(h_tensor, input.requires_grad())
485                };
486                (make_h(), make_h())
487            })
488            .collect();
489
490        // Pre-compute input-to-hidden projection for layer 0 across ALL timesteps
491        // input: [batch, seq, features] -> reshaped to [batch*seq, features]
492        // ih_all: [batch*seq, 4*hidden] = input_2d @ W_ih^T + bias_ih
493        // Note: matmul auto-dispatches to cuBLAS GEMM when tensors are on GPU
494        let cell0 = &self.cells[0];
495        let input_2d = input.reshape(&[batch_size * seq_len, input_features]);
496        let w_ih_t = cell0.weight_ih.variable().transpose(0, 1);
497        let ih_all = input_2d.matmul(&w_ih_t).add_var(&cell0.bias_ih.variable());
498        // ih_all_3d: [batch, seq, 4*hidden]
499        let ih_all_3d = ih_all.reshape(&[batch_size, seq_len, 4 * self.hidden_size]);
500
501        // Hoist weight transpose + bias out of the per-timestep loop
502        let w_hh_t_0 = cell0.weight_hh.variable().transpose(0, 1);
503        let bias_hh_0 = cell0.bias_hh.variable();
504
505        let mut outputs = Vec::with_capacity(seq_len);
506
507        // Check if we're on GPU for fused gate kernel path
508        #[cfg(feature = "cuda")]
509        let on_gpu = input.data().device().is_gpu();
510        #[cfg(not(feature = "cuda"))]
511        let on_gpu = false;
512
513        for t in 0..seq_len {
514            // Layer 0: use pre-computed ih projection + hoisted weight transpose
515            let ih_t = ih_all_3d.select(1, t);
516            let (h, c) = &states[0];
517
518            // h @ W_hh^T + bias_hh (cuBLAS on GPU, matrixmultiply on CPU)
519            let hh = h.matmul(&w_hh_t_0).add_var(&bias_hh_0);
520
521            // Combined gates = ih + hh
522            let gates = ih_t.add_var(&hh);
523
524            if on_gpu {
525                // GPU path: fused LSTM gate kernel (1 launch vs ~14 separate ops)
526                // gates [batch, 4*hidden], c [batch, hidden] → h_new, c_new [batch, hidden]
527                #[cfg(feature = "cuda")]
528                {
529                    let hs = self.hidden_size;
530                    let gates_data = gates.data();
531                    let c_data = c.data();
532
533                    if let Some((h_tensor, c_tensor)) = gates_data.lstm_gates_fused(&c_data, hs) {
534                        // Save forward state for backward
535                        let saved_gates = gates_data.clone();
536                        let saved_c_prev = c_data.clone();
537                        let saved_c_new = c_tensor.clone();
538
539                        // Create proper backward that calls LSTM backward kernel
540                        let backward_fn = axonml_autograd::LstmGatesBackward::new(
541                            gates.grad_fn().cloned(),
542                            c.grad_fn().cloned(),
543                            saved_gates,
544                            saved_c_prev,
545                            saved_c_new,
546                            hs,
547                        );
548                        let grad_fn = axonml_autograd::GradFn::new(backward_fn);
549
550                        let fused_requires_grad = gates.requires_grad() || c.requires_grad();
551                        let h_new = Variable::from_operation(
552                            h_tensor,
553                            grad_fn.clone(),
554                            fused_requires_grad,
555                        );
556                        let c_new =
557                            Variable::from_operation(c_tensor, grad_fn, fused_requires_grad);
558                        states[0] = (h_new, c_new);
559                    }
560                }
561            } else {
562                // CPU path: individual ops (each autograd-tracked)
563                let hs = self.hidden_size;
564                let i_gate = gates.narrow(1, 0, hs).sigmoid();
565                let f_gate = gates.narrow(1, hs, hs).sigmoid();
566                let g_gate = gates.narrow(1, 2 * hs, hs).tanh();
567                let o_gate = gates.narrow(1, 3 * hs, hs).sigmoid();
568                let c_new = f_gate.mul_var(c).add_var(&i_gate.mul_var(&g_gate));
569                let h_new = o_gate.mul_var(&c_new.tanh());
570                states[0] = (h_new, c_new);
571            }
572
573            // Subsequent layers use the regular cell forward_step
574            for l in 1..self.num_layers {
575                let layer_input = states[l - 1].0.clone();
576                states[l] = self.cells[l].forward_step(&layer_input, &states[l]);
577            }
578
579            outputs.push(states[self.num_layers - 1].0.clone());
580        }
581
582        // Stack outputs along the time dimension
583        let time_dim = usize::from(self.batch_first);
584        let unsqueezed: Vec<Variable> = outputs.iter().map(|o| o.unsqueeze(time_dim)).collect();
585        let refs: Vec<&Variable> = unsqueezed.iter().collect();
586        Variable::cat(&refs, time_dim)
587    }
588
589    fn parameters(&self) -> Vec<Parameter> {
590        self.cells.iter().flat_map(|c| c.parameters()).collect()
591    }
592
593    fn named_parameters(&self) -> HashMap<String, Parameter> {
594        let mut params = HashMap::new();
595        if self.cells.len() == 1 {
596            // Single layer: expose directly without cell index prefix
597            for (n, p) in self.cells[0].named_parameters() {
598                params.insert(n, p);
599            }
600        } else {
601            for (i, cell) in self.cells.iter().enumerate() {
602                for (n, p) in cell.named_parameters() {
603                    params.insert(format!("cells.{i}.{n}"), p);
604                }
605            }
606        }
607        params
608    }
609
610    fn name(&self) -> &'static str {
611        "LSTM"
612    }
613}
614
615// =============================================================================
616// GRUCell and GRU
617// =============================================================================
618
619/// A single GRU cell.
620///
621/// h' = (1 - z) * n + z * h
622/// where:
623///   r = sigmoid(W_ir * x + b_ir + W_hr * h + b_hr)  (reset gate)
624///   z = sigmoid(W_iz * x + b_iz + W_hz * h + b_hz)  (update gate)
625///   n = tanh(W_in * x + b_in + r * (W_hn * h + b_hn))  (new gate)
626pub struct GRUCell {
627    /// Input-hidden weights for all gates (reset, update, new).
628    pub weight_ih: Parameter,
629    /// Hidden-hidden weights for all gates (reset, update, new).
630    pub weight_hh: Parameter,
631    /// Input-hidden bias for all gates.
632    pub bias_ih: Parameter,
633    /// Hidden-hidden bias for all gates.
634    pub bias_hh: Parameter,
635    /// Input size.
636    input_size: usize,
637    /// Hidden size.
638    hidden_size: usize,
639}
640
641impl GRUCell {
642    /// Creates a new GRU cell.
643    pub fn new(input_size: usize, hidden_size: usize) -> Self {
644        Self {
645            weight_ih: Parameter::named(
646                "weight_ih",
647                xavier_uniform(input_size, 3 * hidden_size),
648                true,
649            ),
650            weight_hh: Parameter::named(
651                "weight_hh",
652                xavier_uniform(hidden_size, 3 * hidden_size),
653                true,
654            ),
655            bias_ih: Parameter::named("bias_ih", zeros(&[3 * hidden_size]), true),
656            bias_hh: Parameter::named("bias_hh", zeros(&[3 * hidden_size]), true),
657            input_size,
658            hidden_size,
659        }
660    }
661
662    /// Returns the expected input size.
663    pub fn input_size(&self) -> usize {
664        self.input_size
665    }
666
667    /// Returns the hidden state size.
668    pub fn hidden_size(&self) -> usize {
669        self.hidden_size
670    }
671}
672
673impl GRUCell {
674    /// Forward pass for a single time step with explicit hidden state.
675    ///
676    /// GRU equations:
677    /// r_t = sigmoid(W_ir @ x_t + b_ir + W_hr @ h_{t-1} + b_hr)
678    /// z_t = sigmoid(W_iz @ x_t + b_iz + W_hz @ h_{t-1} + b_hz)
679    /// n_t = tanh(W_in @ x_t + b_in + r_t * (W_hn @ h_{t-1} + b_hn))
680    /// h_t = (1 - z_t) * n_t + z_t * h_{t-1}
681    ///
682    /// All computations use Variable operations for proper gradient flow.
683    pub fn forward_step(&self, input: &Variable, hidden: &Variable) -> Variable {
684        let _batch_size = input.shape()[0];
685        let hidden_size = self.hidden_size;
686
687        // Get weight matrices
688        let weight_ih = self.weight_ih.variable();
689        let weight_hh = self.weight_hh.variable();
690        let bias_ih = self.bias_ih.variable();
691        let bias_hh = self.bias_hh.variable();
692
693        // Compute input transformation: x @ W_ih^T + b_ih
694        // Shape: [batch, 3*hidden_size]
695        let weight_ih_t = weight_ih.transpose(0, 1);
696        let ih = input.matmul(&weight_ih_t).add_var(&bias_ih);
697
698        // Compute hidden transformation: h @ W_hh^T + b_hh
699        // Shape: [batch, 3*hidden_size]
700        let weight_hh_t = weight_hh.transpose(0, 1);
701        let hh = hidden.matmul(&weight_hh_t).add_var(&bias_hh);
702
703        // Use narrow to split into gates (preserves gradient flow)
704        // Each gate slice: [batch, hidden_size]
705        let ih_r = ih.narrow(1, 0, hidden_size);
706        let ih_z = ih.narrow(1, hidden_size, hidden_size);
707        let ih_n = ih.narrow(1, 2 * hidden_size, hidden_size);
708
709        let hh_r = hh.narrow(1, 0, hidden_size);
710        let hh_z = hh.narrow(1, hidden_size, hidden_size);
711        let hh_n = hh.narrow(1, 2 * hidden_size, hidden_size);
712
713        // Compute gates using Variable operations for gradient flow
714        // r = sigmoid(ih_r + hh_r)
715        let r = ih_r.add_var(&hh_r).sigmoid();
716
717        // z = sigmoid(ih_z + hh_z)
718        let z = ih_z.add_var(&hh_z).sigmoid();
719
720        // n = tanh(ih_n + r * hh_n)
721        let n = ih_n.add_var(&r.mul_var(&hh_n)).tanh();
722
723        // h_new = (1 - z) * n + z * h_prev
724        // Rewritten as: n + z * (h_prev - n)  to avoid allocating a ones tensor
725        let h_minus_n = hidden.sub_var(&n);
726        n.add_var(&z.mul_var(&h_minus_n))
727    }
728}
729
730impl Module for GRUCell {
731    fn forward(&self, input: &Variable) -> Variable {
732        let batch_size = input.shape()[0];
733
734        // Initialize hidden state to zeros
735        let hidden = Variable::new(
736            zeros(&[batch_size, self.hidden_size]),
737            input.requires_grad(),
738        );
739
740        self.forward_step(input, &hidden)
741    }
742
743    fn parameters(&self) -> Vec<Parameter> {
744        vec![
745            self.weight_ih.clone(),
746            self.weight_hh.clone(),
747            self.bias_ih.clone(),
748            self.bias_hh.clone(),
749        ]
750    }
751
752    fn named_parameters(&self) -> HashMap<String, Parameter> {
753        let mut params = HashMap::new();
754        params.insert("weight_ih".to_string(), self.weight_ih.clone());
755        params.insert("weight_hh".to_string(), self.weight_hh.clone());
756        params.insert("bias_ih".to_string(), self.bias_ih.clone());
757        params.insert("bias_hh".to_string(), self.bias_hh.clone());
758        params
759    }
760
761    fn name(&self) -> &'static str {
762        "GRUCell"
763    }
764}
765
766/// Multi-layer GRU.
767pub struct GRU {
768    /// GRU cells for each layer.
769    cells: Vec<GRUCell>,
770    /// Hidden state size.
771    hidden_size: usize,
772    /// Number of layers.
773    num_layers: usize,
774    /// If true, input is (batch, seq, features), else (seq, batch, features).
775    batch_first: bool,
776}
777
778impl GRU {
779    /// Creates a new multi-layer GRU.
780    pub fn new(input_size: usize, hidden_size: usize, num_layers: usize) -> Self {
781        let mut cells = Vec::with_capacity(num_layers);
782        cells.push(GRUCell::new(input_size, hidden_size));
783        for _ in 1..num_layers {
784            cells.push(GRUCell::new(hidden_size, hidden_size));
785        }
786        Self {
787            cells,
788            hidden_size,
789            num_layers,
790            batch_first: true,
791        }
792    }
793
794    /// Returns the hidden state size.
795    pub fn hidden_size(&self) -> usize {
796        self.hidden_size
797    }
798
799    /// Returns the number of layers.
800    pub fn num_layers(&self) -> usize {
801        self.num_layers
802    }
803}
804
805impl Module for GRU {
806    fn forward(&self, input: &Variable) -> Variable {
807        let shape = input.shape();
808        let (batch_size, seq_len, input_features) = if self.batch_first {
809            (shape[0], shape[1], shape[2])
810        } else {
811            (shape[1], shape[0], shape[2])
812        };
813
814        // Check if we're on GPU for fused gate kernel path
815        #[cfg(feature = "cuda")]
816        let on_gpu = input.data().device().is_gpu();
817        #[cfg(not(feature = "cuda"))]
818        let on_gpu = false;
819
820        let input_device = input.data().device();
821
822        // Initialize hidden states for all layers as Variables (with gradients)
823        // Move to the same device as input so GPU fused kernels receive GPU tensors.
824        let mut hidden_states: Vec<Variable> = (0..self.num_layers)
825            .map(|_| {
826                let h_cpu = zeros(&[batch_size, self.hidden_size]);
827                let h_tensor = if on_gpu {
828                    h_cpu
829                        .to_device(input_device)
830                        .expect("GRU: failed to move hidden state to GPU")
831                } else {
832                    h_cpu
833                };
834                Variable::new(h_tensor, input.requires_grad())
835            })
836            .collect();
837
838        // Pre-compute input-to-hidden projection for layer 0 across ALL timesteps
839        // One big matmul instead of seq_len small ones
840        let cell0 = &self.cells[0];
841        let input_2d = input.reshape(&[batch_size * seq_len, input_features]);
842        let w_ih_t = cell0.weight_ih.variable().transpose(0, 1);
843        let ih_all = input_2d.matmul(&w_ih_t).add_var(&cell0.bias_ih.variable());
844        let ih_all_3d = ih_all.reshape(&[batch_size, seq_len, 3 * self.hidden_size]);
845
846        // Hoist weight transpose + bias out of the per-timestep loop
847        let w_hh_t_0 = cell0.weight_hh.variable().transpose(0, 1);
848        let bias_hh_0 = cell0.bias_hh.variable();
849
850        let mut output_vars: Vec<Variable> = Vec::with_capacity(seq_len);
851
852        for t in 0..seq_len {
853            // Layer 0: use pre-computed ih projection + hoisted weight transpose
854            let ih_t = ih_all_3d.select(1, t);
855            let hidden = &hidden_states[0];
856            let hs = self.hidden_size;
857
858            let hh = hidden.matmul(&w_hh_t_0).add_var(&bias_hh_0);
859
860            if on_gpu {
861                // GPU path: fused GRU gate kernel (1 launch vs ~12 separate ops)
862                // ih_t [batch, 3*hidden], hh [batch, 3*hidden], hidden [batch, hidden] → h_new [batch, hidden]
863                #[cfg(feature = "cuda")]
864                {
865                    let ih_data = ih_t.data();
866                    let hh_data = hh.data();
867                    let h_data = hidden.data();
868
869                    if let Some(h_tensor) = ih_data.gru_gates_fused(&hh_data, &h_data, hs) {
870                        // Save forward state for backward
871                        let saved_ih = ih_data.clone();
872                        let saved_hh = hh_data.clone();
873                        let saved_h_prev = h_data.clone();
874
875                        // Create proper backward that calls GRU backward kernel
876                        let backward_fn = axonml_autograd::GruGatesBackward::new(
877                            ih_t.grad_fn().cloned(),
878                            hh.grad_fn().cloned(),
879                            hidden.grad_fn().cloned(),
880                            saved_ih,
881                            saved_hh,
882                            saved_h_prev,
883                            hs,
884                        );
885                        let grad_fn = axonml_autograd::GradFn::new(backward_fn);
886
887                        // Use requires_grad=true if ANY input to the fused op
888                        // requires grad — the GRU parameters (w_ih, w_hh, bias)
889                        // always require grad during training, so ih_t and hh
890                        // will have requires_grad=true even when the raw input
891                        // Variable does not.
892                        let fused_requires_grad =
893                            ih_t.requires_grad() || hh.requires_grad() || hidden.requires_grad();
894                        let h_new =
895                            Variable::from_operation(h_tensor, grad_fn, fused_requires_grad);
896                        hidden_states[0] = h_new;
897                    }
898                }
899            } else {
900                // CPU path: individual ops (each autograd-tracked)
901                let ih_r = ih_t.narrow(1, 0, hs);
902                let ih_z = ih_t.narrow(1, hs, hs);
903                let ih_n = ih_t.narrow(1, 2 * hs, hs);
904                let hh_r = hh.narrow(1, 0, hs);
905                let hh_z = hh.narrow(1, hs, hs);
906                let hh_n = hh.narrow(1, 2 * hs, hs);
907
908                let r = ih_r.add_var(&hh_r).sigmoid();
909                let z = ih_z.add_var(&hh_z).sigmoid();
910                let n = ih_n.add_var(&r.mul_var(&hh_n)).tanh();
911                let h_minus_n = hidden.sub_var(&n);
912                let h_new = n.add_var(&z.mul_var(&h_minus_n));
913                hidden_states[0] = h_new;
914            }
915
916            // Subsequent layers use the regular cell forward_step
917            let mut layer_output = hidden_states[0].clone();
918            for l in 1..self.num_layers {
919                let new_hidden = self.cells[l].forward_step(&layer_output, &hidden_states[l]);
920                hidden_states[l] = new_hidden.clone();
921                layer_output = new_hidden;
922            }
923
924            output_vars.push(layer_output);
925        }
926
927        // Stack outputs along the time dimension
928        self.stack_outputs(&output_vars, batch_size, seq_len)
929    }
930
931    fn parameters(&self) -> Vec<Parameter> {
932        self.cells.iter().flat_map(|c| c.parameters()).collect()
933    }
934
935    fn named_parameters(&self) -> HashMap<String, Parameter> {
936        let mut params = HashMap::new();
937        if self.cells.len() == 1 {
938            for (n, p) in self.cells[0].named_parameters() {
939                params.insert(n, p);
940            }
941        } else {
942            for (i, cell) in self.cells.iter().enumerate() {
943                for (n, p) in cell.named_parameters() {
944                    params.insert(format!("cells.{i}.{n}"), p);
945                }
946            }
947        }
948        params
949    }
950
951    fn name(&self) -> &'static str {
952        "GRU"
953    }
954}
955
956impl GRU {
957    /// Forward pass that returns the mean of all hidden states.
958    /// This is equivalent to processing then mean pooling, but with proper gradient flow.
959    pub fn forward_mean(&self, input: &Variable) -> Variable {
960        let shape = input.shape();
961        let (batch_size, seq_len, input_features) = if self.batch_first {
962            (shape[0], shape[1], shape[2])
963        } else {
964            (shape[1], shape[0], shape[2])
965        };
966
967        let mut hidden_states: Vec<Variable> = (0..self.num_layers)
968            .map(|_| {
969                Variable::new(
970                    zeros(&[batch_size, self.hidden_size]),
971                    input.requires_grad(),
972                )
973            })
974            .collect();
975
976        // Pre-compute input-to-hidden projection for layer 0 across ALL timesteps
977        let cell0 = &self.cells[0];
978        let input_2d = input.reshape(&[batch_size * seq_len, input_features]);
979        let w_ih_t = cell0.weight_ih.variable().transpose(0, 1);
980        let ih_all = input_2d.matmul(&w_ih_t).add_var(&cell0.bias_ih.variable());
981        let ih_all_3d = ih_all.reshape(&[batch_size, seq_len, 3 * self.hidden_size]);
982
983        // Hoist weight transpose + bias out of per-timestep loop
984        let w_hh_t_0 = cell0.weight_hh.variable().transpose(0, 1);
985        let bias_hh_0 = cell0.bias_hh.variable();
986
987        let mut output_sum: Option<Variable> = None;
988        let hs = self.hidden_size;
989
990        for t in 0..seq_len {
991            // Layer 0: use pre-computed ih projection + hoisted weight transpose
992            let ih_t = ih_all_3d.select(1, t);
993            let hidden = &hidden_states[0];
994            let hh = hidden.matmul(&w_hh_t_0).add_var(&bias_hh_0);
995
996            let ih_r = ih_t.narrow(1, 0, hs);
997            let ih_z = ih_t.narrow(1, hs, hs);
998            let ih_n = ih_t.narrow(1, 2 * hs, hs);
999            let hh_r = hh.narrow(1, 0, hs);
1000            let hh_z = hh.narrow(1, hs, hs);
1001            let hh_n = hh.narrow(1, 2 * hs, hs);
1002
1003            let r = ih_r.add_var(&hh_r).sigmoid();
1004            let z = ih_z.add_var(&hh_z).sigmoid();
1005            let n = ih_n.add_var(&r.mul_var(&hh_n)).tanh();
1006            let h_minus_n = hidden.sub_var(&n);
1007            let h_new = n.add_var(&z.mul_var(&h_minus_n));
1008            hidden_states[0] = h_new.clone();
1009
1010            // Subsequent layers
1011            let mut layer_output = h_new;
1012            for l in 1..self.num_layers {
1013                let new_hidden = self.cells[l].forward_step(&layer_output, &hidden_states[l]);
1014                hidden_states[l] = new_hidden.clone();
1015                layer_output = new_hidden;
1016            }
1017
1018            output_sum = Some(match output_sum {
1019                None => layer_output,
1020                Some(acc) => acc.add_var(&layer_output),
1021            });
1022        }
1023
1024        match output_sum {
1025            Some(sum) => sum.mul_scalar(1.0 / seq_len as f32),
1026            None => Variable::new(zeros(&[batch_size, self.hidden_size]), false),
1027        }
1028    }
1029
1030    /// Forward pass that returns the last hidden state.
1031    /// Good for sequence classification with proper gradient flow.
1032    pub fn forward_last(&self, input: &Variable) -> Variable {
1033        let shape = input.shape();
1034        let (batch_size, seq_len, input_features) = if self.batch_first {
1035            (shape[0], shape[1], shape[2])
1036        } else {
1037            (shape[1], shape[0], shape[2])
1038        };
1039
1040        let mut hidden_states: Vec<Variable> = (0..self.num_layers)
1041            .map(|_| {
1042                Variable::new(
1043                    zeros(&[batch_size, self.hidden_size]),
1044                    input.requires_grad(),
1045                )
1046            })
1047            .collect();
1048
1049        // Pre-compute input-to-hidden projection for layer 0 across ALL timesteps
1050        let cell0 = &self.cells[0];
1051        let input_2d = input.reshape(&[batch_size * seq_len, input_features]);
1052        let w_ih_t = cell0.weight_ih.variable().transpose(0, 1);
1053        let ih_all = input_2d.matmul(&w_ih_t).add_var(&cell0.bias_ih.variable());
1054        let ih_all_3d = ih_all.reshape(&[batch_size, seq_len, 3 * self.hidden_size]);
1055
1056        // Hoist weight transpose + bias out of per-timestep loop
1057        let w_hh_t_0 = cell0.weight_hh.variable().transpose(0, 1);
1058        let bias_hh_0 = cell0.bias_hh.variable();
1059        let hs = self.hidden_size;
1060
1061        for t in 0..seq_len {
1062            // Layer 0: use pre-computed ih projection + hoisted weight transpose
1063            let ih_t = ih_all_3d.select(1, t);
1064            let hidden = &hidden_states[0];
1065            let hh = hidden.matmul(&w_hh_t_0).add_var(&bias_hh_0);
1066
1067            let ih_r = ih_t.narrow(1, 0, hs);
1068            let ih_z = ih_t.narrow(1, hs, hs);
1069            let ih_n = ih_t.narrow(1, 2 * hs, hs);
1070            let hh_r = hh.narrow(1, 0, hs);
1071            let hh_z = hh.narrow(1, hs, hs);
1072            let hh_n = hh.narrow(1, 2 * hs, hs);
1073
1074            let r = ih_r.add_var(&hh_r).sigmoid();
1075            let z = ih_z.add_var(&hh_z).sigmoid();
1076            let n = ih_n.add_var(&r.mul_var(&hh_n)).tanh();
1077            let h_minus_n = hidden.sub_var(&n);
1078            let h_new = n.add_var(&z.mul_var(&h_minus_n));
1079            hidden_states[0] = h_new.clone();
1080
1081            // Subsequent layers
1082            let mut layer_input = h_new;
1083
1084            for (layer_idx, cell) in self.cells.iter().enumerate().skip(1) {
1085                let new_hidden = cell.forward_step(&layer_input, &hidden_states[layer_idx]);
1086                hidden_states[layer_idx] = new_hidden.clone();
1087                layer_input = new_hidden;
1088            }
1089        }
1090
1091        // Return last hidden state from last layer
1092        hidden_states
1093            .pop()
1094            .unwrap_or_else(|| Variable::new(zeros(&[batch_size, self.hidden_size]), false))
1095    }
1096
1097    /// Stack output Variables into a single [batch, seq, hidden] tensor.
1098    /// Note: This creates a new tensor without gradient connections to individual timesteps.
1099    /// For gradient flow, use forward_mean() or forward_last() instead.
1100    fn stack_outputs(&self, outputs: &[Variable], batch_size: usize, _seq_len: usize) -> Variable {
1101        if outputs.is_empty() {
1102            return Variable::new(zeros(&[batch_size, 0, self.hidden_size]), false);
1103        }
1104
1105        // Unsqueeze each (batch, hidden) → (batch, 1, hidden), then cat along dim=1
1106        let unsqueezed: Vec<Variable> = outputs.iter().map(|o| o.unsqueeze(1)).collect();
1107        let refs: Vec<&Variable> = unsqueezed.iter().collect();
1108        Variable::cat(&refs, 1)
1109    }
1110}
1111
1112// =============================================================================
1113// Tests
1114// =============================================================================
1115
1116#[cfg(test)]
1117mod tests {
1118    use super::*;
1119    use axonml_tensor::Tensor;
1120
1121    #[test]
1122    fn test_rnn_cell() {
1123        let cell = RNNCell::new(10, 20);
1124        let input = Variable::new(Tensor::from_vec(vec![1.0; 20], &[2, 10]).unwrap(), false);
1125        let hidden = Variable::new(Tensor::from_vec(vec![0.0; 40], &[2, 20]).unwrap(), false);
1126        let output = cell.forward_step(&input, &hidden);
1127        assert_eq!(output.shape(), vec![2, 20]);
1128    }
1129
1130    #[test]
1131    fn test_rnn() {
1132        let rnn = RNN::new(10, 20, 2);
1133        let input = Variable::new(
1134            Tensor::from_vec(vec![1.0; 100], &[2, 5, 10]).unwrap(),
1135            false,
1136        );
1137        let output = rnn.forward(&input);
1138        assert_eq!(output.shape(), vec![2, 5, 20]);
1139    }
1140
1141    #[test]
1142    fn test_lstm() {
1143        let lstm = LSTM::new(10, 20, 1);
1144        let input = Variable::new(
1145            Tensor::from_vec(vec![1.0; 100], &[2, 5, 10]).unwrap(),
1146            false,
1147        );
1148        let output = lstm.forward(&input);
1149        assert_eq!(output.shape(), vec![2, 5, 20]);
1150    }
1151
1152    #[test]
1153    fn test_gru_gradients_reach_parameters() {
1154        let gru = GRU::new(4, 8, 1);
1155        let input = Variable::new(
1156            Tensor::from_vec(vec![0.5f32; 2 * 3 * 4], &[2, 3, 4]).unwrap(),
1157            true,
1158        );
1159        let output = gru.forward(&input);
1160        println!(
1161            "Output shape: {:?}, requires_grad: {}",
1162            output.shape(),
1163            output.requires_grad()
1164        );
1165        let loss = output.sum();
1166        println!(
1167            "Loss: {:?}, requires_grad: {}",
1168            loss.data().to_vec(),
1169            loss.requires_grad()
1170        );
1171        loss.backward();
1172
1173        // Check input gradient
1174        println!(
1175            "Input grad: {:?}",
1176            input
1177                .grad()
1178                .map(|g| g.to_vec().iter().map(|x| x.abs()).sum::<f32>())
1179        );
1180
1181        let params = gru.parameters();
1182        println!("Number of parameters: {}", params.len());
1183        let mut has_grad = false;
1184        for (i, p) in params.iter().enumerate() {
1185            let grad = p.grad();
1186            match grad {
1187                Some(g) => {
1188                    let gv = g.to_vec();
1189                    let sum_abs: f32 = gv.iter().map(|x| x.abs()).sum();
1190                    println!(
1191                        "Param {} shape {:?} requires_grad={}: grad sum_abs={:.6}",
1192                        i,
1193                        p.shape(),
1194                        p.requires_grad(),
1195                        sum_abs
1196                    );
1197                    if sum_abs > 0.0 {
1198                        has_grad = true;
1199                    }
1200                }
1201                None => {
1202                    println!(
1203                        "Param {} shape {:?} requires_grad={}: NO GRADIENT",
1204                        i,
1205                        p.shape(),
1206                        p.requires_grad()
1207                    );
1208                }
1209            }
1210        }
1211        assert!(
1212            has_grad,
1213            "At least one GRU parameter should have non-zero gradients"
1214        );
1215    }
1216
1217    // =========================================================================
1218    // LSTM Comprehensive
1219    // =========================================================================
1220
1221    #[test]
1222    fn test_lstm_cell_forward_step() {
1223        let cell = LSTMCell::new(8, 16);
1224        let input = Variable::new(Tensor::from_vec(vec![1.0; 2 * 8], &[2, 8]).unwrap(), false);
1225        let hidden = Variable::new(
1226            Tensor::from_vec(vec![0.0; 2 * 16], &[2, 16]).unwrap(),
1227            false,
1228        );
1229        let cell_state = Variable::new(
1230            Tensor::from_vec(vec![0.0; 2 * 16], &[2, 16]).unwrap(),
1231            false,
1232        );
1233        let hx = (hidden, cell_state);
1234        let (h, c) = cell.forward_step(&input, &hx);
1235        assert_eq!(h.shape(), vec![2, 16]);
1236        assert_eq!(c.shape(), vec![2, 16]);
1237    }
1238
1239    #[test]
1240    fn test_lstm_multi_layer() {
1241        let lstm = LSTM::new(8, 16, 3); // 3 layers
1242        assert_eq!(lstm.num_layers(), 3);
1243        assert_eq!(lstm.hidden_size(), 16);
1244
1245        let input = Variable::new(
1246            Tensor::from_vec(vec![0.5; 2 * 5 * 8], &[2, 5, 8]).unwrap(),
1247            false,
1248        );
1249        let output = lstm.forward(&input);
1250        assert_eq!(output.shape(), vec![2, 5, 16]);
1251    }
1252
1253    #[test]
1254    fn test_lstm_forward_last() {
1255        let lstm = LSTM::new(8, 16, 1);
1256        let input = Variable::new(
1257            Tensor::from_vec(vec![1.0; 2 * 10 * 8], &[2, 10, 8]).unwrap(),
1258            false,
1259        );
1260        // forward_last should return only the last time step
1261        // The LSTM module may not have forward_last, but forward returns [B, T, H]
1262        let output = lstm.forward(&input);
1263        assert_eq!(output.shape(), vec![2, 10, 16]);
1264
1265        // Last timestep extraction
1266        let out_vec = output.data().to_vec();
1267        let last_t0 = &out_vec[9 * 16..10 * 16]; // batch 0, time 9
1268        assert!(
1269            last_t0.iter().all(|v| v.is_finite()),
1270            "Last output should be finite"
1271        );
1272    }
1273
1274    #[test]
1275    fn test_lstm_gradient_flow() {
1276        let lstm = LSTM::new(4, 8, 1);
1277        let input = Variable::new(
1278            Tensor::from_vec(vec![0.5; 3 * 4], &[1, 3, 4]).unwrap(),
1279            true,
1280        );
1281        let output = lstm.forward(&input);
1282        let loss = output.sum();
1283        loss.backward();
1284
1285        let input_grad = input
1286            .grad()
1287            .expect("Input should have gradient through LSTM");
1288        assert_eq!(input_grad.shape(), &[1, 3, 4]);
1289        assert!(
1290            input_grad.to_vec().iter().any(|g| g.abs() > 1e-10),
1291            "LSTM should propagate gradients to input"
1292        );
1293
1294        // Parameters should also have gradients
1295        let params = lstm.parameters();
1296        let grads_exist = params.iter().any(|p| {
1297            p.grad()
1298                .is_some_and(|g| g.to_vec().iter().any(|v| v.abs() > 0.0))
1299        });
1300        assert!(grads_exist, "LSTM parameters should have gradients");
1301    }
1302
1303    #[test]
1304    fn test_lstm_different_sequence_lengths() {
1305        let lstm = LSTM::new(4, 8, 1);
1306
1307        // Short sequence
1308        let short = Variable::new(
1309            Tensor::from_vec(vec![1.0; 2 * 4], &[1, 2, 4]).unwrap(),
1310            false,
1311        );
1312        let out_short = lstm.forward(&short);
1313        assert_eq!(out_short.shape(), vec![1, 2, 8]);
1314
1315        // Long sequence
1316        let long = Variable::new(
1317            Tensor::from_vec(vec![1.0; 20 * 4], &[1, 20, 4]).unwrap(),
1318            false,
1319        );
1320        let out_long = lstm.forward(&long);
1321        assert_eq!(out_long.shape(), vec![1, 20, 8]);
1322    }
1323
1324    #[test]
1325    fn test_lstm_parameters_count() {
1326        // LSTM has 4 gates (i, f, g, o), each with input and hidden weights + biases
1327        // Per layer: 4 * (input_size * hidden_size + hidden_size * hidden_size + 2 * hidden_size)
1328        let lstm = LSTM::new(10, 20, 1);
1329        let n = lstm.parameters().iter().map(|p| p.numel()).sum::<usize>();
1330        // Expected: 4 * (10*20 + 20*20 + 20 + 20) = 4 * (200 + 400 + 40) = 2560
1331        assert!(n > 0, "LSTM should have parameters");
1332    }
1333
1334    // =========================================================================
1335    // GRU Comprehensive
1336    // =========================================================================
1337
1338    #[test]
1339    fn test_gru_cell_forward_step() {
1340        let cell = GRUCell::new(8, 16);
1341        assert_eq!(cell.input_size(), 8);
1342        assert_eq!(cell.hidden_size(), 16);
1343
1344        let input = Variable::new(Tensor::from_vec(vec![1.0; 2 * 8], &[2, 8]).unwrap(), false);
1345        let hidden = Variable::new(
1346            Tensor::from_vec(vec![0.0; 2 * 16], &[2, 16]).unwrap(),
1347            false,
1348        );
1349        let output = cell.forward_step(&input, &hidden);
1350        assert_eq!(output.shape(), vec![2, 16]);
1351    }
1352
1353    #[test]
1354    fn test_gru_multi_layer() {
1355        let gru = GRU::new(8, 16, 2);
1356        assert_eq!(gru.num_layers(), 2);
1357        assert_eq!(gru.hidden_size(), 16);
1358
1359        let input = Variable::new(
1360            Tensor::from_vec(vec![0.5; 2 * 5 * 8], &[2, 5, 8]).unwrap(),
1361            false,
1362        );
1363        let output = gru.forward(&input);
1364        assert_eq!(output.shape(), vec![2, 5, 16]);
1365    }
1366
1367    #[test]
1368    fn test_gru_forward_mean() {
1369        let gru = GRU::new(4, 8, 1);
1370        let input = Variable::new(
1371            Tensor::from_vec(vec![1.0; 2 * 5 * 4], &[2, 5, 4]).unwrap(),
1372            false,
1373        );
1374        let mean_out = gru.forward_mean(&input);
1375        // forward_mean averages over time: [B, T, H] → [B, H]
1376        assert_eq!(mean_out.shape(), vec![2, 8]);
1377    }
1378
1379    #[test]
1380    fn test_gru_forward_last() {
1381        let gru = GRU::new(4, 8, 1);
1382        let input = Variable::new(
1383            Tensor::from_vec(vec![1.0; 2 * 5 * 4], &[2, 5, 4]).unwrap(),
1384            false,
1385        );
1386        let last_out = gru.forward_last(&input);
1387        // forward_last returns only last timestep: [B, T, H] → [B, H]
1388        assert_eq!(last_out.shape(), vec![2, 8]);
1389    }
1390
1391    #[test]
1392    fn test_gru_gradient_flow_to_input() {
1393        let gru = GRU::new(4, 8, 1);
1394        let input = Variable::new(
1395            Tensor::from_vec(vec![0.5; 3 * 4], &[1, 3, 4]).unwrap(),
1396            true,
1397        );
1398        let output = gru.forward(&input);
1399        output.sum().backward();
1400
1401        let grad = input
1402            .grad()
1403            .expect("Input should have gradient through GRU");
1404        assert_eq!(grad.shape(), &[1, 3, 4]);
1405        assert!(
1406            grad.to_vec().iter().any(|g| g.abs() > 1e-10),
1407            "GRU should propagate gradients"
1408        );
1409    }
1410
1411    #[test]
1412    fn test_gru_hidden_state_evolves() {
1413        let gru = GRU::new(4, 8, 1);
1414        let input = Variable::new(
1415            Tensor::from_vec(vec![1.0; 5 * 4], &[1, 5, 4]).unwrap(),
1416            false,
1417        );
1418        let output = gru.forward(&input);
1419        let out_vec = output.data().to_vec();
1420
1421        // Hidden states at different timesteps should differ
1422        let t0 = &out_vec[0..8];
1423        let t4 = &out_vec[4 * 8..5 * 8];
1424        let diff: f32 = t0.iter().zip(t4.iter()).map(|(a, b)| (a - b).abs()).sum();
1425        assert!(
1426            diff > 1e-6,
1427            "GRU hidden state should evolve over time, diff={}",
1428            diff
1429        );
1430    }
1431
1432    // =========================================================================
1433    // RNN Basic
1434    // =========================================================================
1435
1436    #[test]
1437    fn test_rnn_cell_gradient_flow() {
1438        let cell = RNNCell::new(4, 8);
1439        let input = Variable::new(Tensor::from_vec(vec![1.0; 4], &[1, 4]).unwrap(), true);
1440        let hidden = Variable::new(Tensor::from_vec(vec![0.0; 8], &[1, 8]).unwrap(), false);
1441        let out = cell.forward_step(&input, &hidden);
1442        out.sum().backward();
1443
1444        let grad = input.grad().expect("RNNCell should propagate gradients");
1445        assert_eq!(grad.shape(), &[1, 4]);
1446    }
1447
1448    #[test]
1449    fn test_rnn_multi_layer() {
1450        let rnn = RNN::with_options(8, 16, 3, true); // 3 layers, bias
1451        let input = Variable::new(
1452            Tensor::from_vec(vec![0.5; 2 * 5 * 8], &[2, 5, 8]).unwrap(),
1453            false,
1454        );
1455        let output = rnn.forward(&input);
1456        assert_eq!(output.shape(), vec![2, 5, 16]);
1457    }
1458
1459    // =========================================================================
1460    // Numerical Stability
1461    // =========================================================================
1462
1463    #[test]
1464    fn test_lstm_outputs_are_bounded() {
1465        // LSTM should produce bounded outputs (tanh output gate)
1466        let lstm = LSTM::new(4, 8, 1);
1467        let input = Variable::new(
1468            Tensor::from_vec(vec![100.0; 10 * 4], &[1, 10, 4]).unwrap(),
1469            false,
1470        );
1471        let output = lstm.forward(&input);
1472        let out_vec = output.data().to_vec();
1473
1474        // All outputs should be in [-1, 1] range (tanh bounded)
1475        for v in &out_vec {
1476            assert!(v.is_finite(), "LSTM output should be finite, got {}", v);
1477            assert!(
1478                v.abs() <= 1.0 + 1e-5,
1479                "LSTM output should be bounded by tanh: got {}",
1480                v
1481            );
1482        }
1483    }
1484
1485    #[test]
1486    fn test_gru_outputs_finite_with_large_input() {
1487        let gru = GRU::new(4, 8, 1);
1488        let input = Variable::new(
1489            Tensor::from_vec(vec![50.0; 5 * 4], &[1, 5, 4]).unwrap(),
1490            false,
1491        );
1492        let output = gru.forward(&input);
1493        assert!(
1494            output.data().to_vec().iter().all(|v| v.is_finite()),
1495            "GRU should produce finite outputs for large inputs"
1496        );
1497    }
1498}