//! Recurrent Neural Network Layers - RNN, LSTM, GRU
//!
//! # File
//! `crates/axonml-nn/src/layers/rnn.rs`
//!
//! # Author
//! Andrew Jewell Sr - AutomataNexus
//!
//! # Updated
//! March 8, 2026
//!
//! # Disclaimer
//! Use at own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.

use std::collections::HashMap;

use axonml_autograd::Variable;

use crate::init::{xavier_uniform, zeros};
use crate::module::Module;
use crate::parameter::Parameter;

// =============================================================================
// RNNCell
// =============================================================================

/// A single RNN cell.
///
/// Computes one recurrence step:
/// h' = tanh(W_ih * x + b_ih + W_hh * h + b_hh)
pub struct RNNCell {
    /// Input-to-hidden weight matrix (learnable).
    pub weight_ih: Parameter,
    /// Hidden-to-hidden (recurrent) weight matrix (learnable).
    pub weight_hh: Parameter,
    /// Bias added to the input projection (learnable).
    pub bias_ih: Parameter,
    /// Bias added to the hidden projection (learnable).
    pub bias_hh: Parameter,
    /// Number of features expected in each input vector.
    input_size: usize,
    /// Number of features in the hidden state.
    hidden_size: usize,
}
46
47impl RNNCell {
48    /// Creates a new RNNCell.
49    pub fn new(input_size: usize, hidden_size: usize) -> Self {
50        Self {
51            weight_ih: Parameter::named("weight_ih", xavier_uniform(input_size, hidden_size), true),
52            weight_hh: Parameter::named(
53                "weight_hh",
54                xavier_uniform(hidden_size, hidden_size),
55                true,
56            ),
57            bias_ih: Parameter::named("bias_ih", zeros(&[hidden_size]), true),
58            bias_hh: Parameter::named("bias_hh", zeros(&[hidden_size]), true),
59            input_size,
60            hidden_size,
61        }
62    }
63
64    /// Returns the expected input size.
65    pub fn input_size(&self) -> usize {
66        self.input_size
67    }
68
69    /// Returns the hidden state size.
70    pub fn hidden_size(&self) -> usize {
71        self.hidden_size
72    }
73
74    /// Forward pass for a single time step.
75    pub fn forward_step(&self, input: &Variable, hidden: &Variable) -> Variable {
76        let input_features = input.data().shape().last().copied().unwrap_or(0);
77        assert_eq!(
78            input_features, self.input_size,
79            "RNNCell: expected input size {}, got {}",
80            self.input_size, input_features
81        );
82        // x @ W_ih^T + b_ih
83        let weight_ih = self.weight_ih.variable();
84        let weight_ih_t = weight_ih.transpose(0, 1);
85        let ih = input.matmul(&weight_ih_t);
86        let bias_ih = self.bias_ih.variable();
87        let ih = ih.add_var(&bias_ih);
88
89        // h @ W_hh^T + b_hh
90        let weight_hh = self.weight_hh.variable();
91        let weight_hh_t = weight_hh.transpose(0, 1);
92        let hh = hidden.matmul(&weight_hh_t);
93        let bias_hh = self.bias_hh.variable();
94        let hh = hh.add_var(&bias_hh);
95
96        // tanh(ih + hh)
97        ih.add_var(&hh).tanh()
98    }
99}
100
101impl Module for RNNCell {
102    fn forward(&self, input: &Variable) -> Variable {
103        // Initialize hidden state to zeros
104        let batch_size = input.shape()[0];
105        let hidden = Variable::new(
106            zeros(&[batch_size, self.hidden_size]),
107            input.requires_grad(),
108        );
109        self.forward_step(input, &hidden)
110    }
111
112    fn parameters(&self) -> Vec<Parameter> {
113        vec![
114            self.weight_ih.clone(),
115            self.weight_hh.clone(),
116            self.bias_ih.clone(),
117            self.bias_hh.clone(),
118        ]
119    }
120
121    fn named_parameters(&self) -> HashMap<String, Parameter> {
122        let mut params = HashMap::new();
123        params.insert("weight_ih".to_string(), self.weight_ih.clone());
124        params.insert("weight_hh".to_string(), self.weight_hh.clone());
125        params.insert("bias_ih".to_string(), self.bias_ih.clone());
126        params.insert("bias_hh".to_string(), self.bias_hh.clone());
127        params
128    }
129
130    fn name(&self) -> &'static str {
131        "RNNCell"
132    }
133}
134
// =============================================================================
// RNN
// =============================================================================

/// Multi-layer RNN.
///
/// Processes sequences through stacked RNN layers.
pub struct RNN {
    /// One `RNNCell` per stacked layer; layer 0 consumes the raw input.
    cells: Vec<RNNCell>,
    /// Input feature size accepted by layer 0 (retained for clarity only).
    _input_size: usize,
    /// Hidden state size shared by every layer.
    hidden_size: usize,
    /// Number of stacked layers.
    num_layers: usize,
    /// If true, input is (batch, seq, features); otherwise (seq, batch, features).
    batch_first: bool,
}
154
155impl RNN {
156    /// Creates a new multi-layer RNN.
157    pub fn new(input_size: usize, hidden_size: usize, num_layers: usize) -> Self {
158        Self::with_options(input_size, hidden_size, num_layers, true)
159    }
160
161    /// Creates an RNN with all options.
162    pub fn with_options(
163        input_size: usize,
164        hidden_size: usize,
165        num_layers: usize,
166        batch_first: bool,
167    ) -> Self {
168        let mut cells = Vec::with_capacity(num_layers);
169
170        // First layer takes input_size
171        cells.push(RNNCell::new(input_size, hidden_size));
172
173        // Subsequent layers take hidden_size
174        for _ in 1..num_layers {
175            cells.push(RNNCell::new(hidden_size, hidden_size));
176        }
177
178        Self {
179            cells,
180            _input_size: input_size,
181            hidden_size,
182            num_layers,
183            batch_first,
184        }
185    }
186}
187
188impl Module for RNN {
189    fn forward(&self, input: &Variable) -> Variable {
190        let shape = input.shape();
191        let (batch_size, seq_len, input_features) = if self.batch_first {
192            (shape[0], shape[1], shape[2])
193        } else {
194            (shape[1], shape[0], shape[2])
195        };
196
197        // Initialize hidden states
198        let mut hiddens: Vec<Variable> = (0..self.num_layers)
199            .map(|_| {
200                Variable::new(
201                    zeros(&[batch_size, self.hidden_size]),
202                    input.requires_grad(),
203                )
204            })
205            .collect();
206
207        // Pre-compute input-to-hidden projection for layer 0 across ALL timesteps
208        let cell0 = &self.cells[0];
209        let input_2d = input.reshape(&[batch_size * seq_len, input_features]);
210        let w_ih_t = cell0.weight_ih.variable().transpose(0, 1);
211        let ih_all = input_2d.matmul(&w_ih_t).add_var(&cell0.bias_ih.variable());
212        let ih_all_3d = ih_all.reshape(&[batch_size, seq_len, self.hidden_size]);
213
214        // Hoist weight transposes out of the per-timestep loop
215        let w_hh_t_0 = cell0.weight_hh.variable().transpose(0, 1);
216        let bias_hh_0 = cell0.bias_hh.variable();
217
218        let mut outputs = Vec::with_capacity(seq_len);
219
220        for t in 0..seq_len {
221            // Layer 0: use pre-computed ih projection + hoisted weight transpose
222            let ih_t = ih_all_3d.select(1, t);
223            let hh = hiddens[0].matmul(&w_hh_t_0).add_var(&bias_hh_0);
224            hiddens[0] = ih_t.add_var(&hh).tanh();
225
226            // Subsequent layers
227            for l in 1..self.num_layers {
228                let layer_input = hiddens[l - 1].clone();
229                hiddens[l] = self.cells[l].forward_step(&layer_input, &hiddens[l]);
230            }
231
232            outputs.push(hiddens[self.num_layers - 1].clone());
233        }
234
235        // Stack outputs using graph-tracked cat (unsqueeze + cat along time dim)
236        let time_dim = usize::from(self.batch_first);
237        let unsqueezed: Vec<Variable> = outputs.iter().map(|o| o.unsqueeze(time_dim)).collect();
238        let refs: Vec<&Variable> = unsqueezed.iter().collect();
239        Variable::cat(&refs, time_dim)
240    }
241
242    fn parameters(&self) -> Vec<Parameter> {
243        self.cells.iter().flat_map(|c| c.parameters()).collect()
244    }
245
246    fn name(&self) -> &'static str {
247        "RNN"
248    }
249}
250
// =============================================================================
// LSTMCell
// =============================================================================

/// A single LSTM cell.
///
/// The four gates (input, forget, candidate, output — sliced in that order in
/// `forward_step`) are packed together, so every learnable tensor carries a
/// `4 * hidden_size` gate dimension.
pub struct LSTMCell {
    /// Input-hidden weights for all gates.
    pub weight_ih: Parameter,
    /// Hidden-hidden weights for all gates.
    pub weight_hh: Parameter,
    /// Input-hidden bias for all gates.
    pub bias_ih: Parameter,
    /// Hidden-hidden bias for all gates.
    pub bias_hh: Parameter,
    /// Number of features expected in each input vector.
    input_size: usize,
    /// Number of features in the hidden and cell states.
    hidden_size: usize,
}
270
271impl LSTMCell {
272    /// Creates a new LSTMCell.
273    pub fn new(input_size: usize, hidden_size: usize) -> Self {
274        // LSTM has 4 gates, so weight size is 4*hidden_size
275        Self {
276            weight_ih: Parameter::named(
277                "weight_ih",
278                xavier_uniform(input_size, 4 * hidden_size),
279                true,
280            ),
281            weight_hh: Parameter::named(
282                "weight_hh",
283                xavier_uniform(hidden_size, 4 * hidden_size),
284                true,
285            ),
286            bias_ih: Parameter::named("bias_ih", zeros(&[4 * hidden_size]), true),
287            bias_hh: Parameter::named("bias_hh", zeros(&[4 * hidden_size]), true),
288            input_size,
289            hidden_size,
290        }
291    }
292
293    /// Returns the expected input size.
294    pub fn input_size(&self) -> usize {
295        self.input_size
296    }
297
298    /// Returns the hidden state size.
299    pub fn hidden_size(&self) -> usize {
300        self.hidden_size
301    }
302
303    /// Forward pass returning (h', c').
304    pub fn forward_step(
305        &self,
306        input: &Variable,
307        hx: &(Variable, Variable),
308    ) -> (Variable, Variable) {
309        let input_features = input.data().shape().last().copied().unwrap_or(0);
310        assert_eq!(
311            input_features, self.input_size,
312            "LSTMCell: expected input size {}, got {}",
313            self.input_size, input_features
314        );
315
316        let (h, c) = hx;
317
318        // Compute all gates at once (x @ W^T + b)
319        let weight_ih = self.weight_ih.variable();
320        let weight_ih_t = weight_ih.transpose(0, 1);
321        let ih = input.matmul(&weight_ih_t);
322        let bias_ih = self.bias_ih.variable();
323        let ih = ih.add_var(&bias_ih);
324
325        let weight_hh = self.weight_hh.variable();
326        let weight_hh_t = weight_hh.transpose(0, 1);
327        let hh = h.matmul(&weight_hh_t);
328        let bias_hh = self.bias_hh.variable();
329        let hh = hh.add_var(&bias_hh);
330
331        let gates = ih.add_var(&hh);
332        let hs = self.hidden_size;
333
334        // Split into 4 gates using narrow (preserves gradient flow)
335        let i = gates.narrow(1, 0, hs).sigmoid();
336        let f = gates.narrow(1, hs, hs).sigmoid();
337        let g = gates.narrow(1, 2 * hs, hs).tanh();
338        let o = gates.narrow(1, 3 * hs, hs).sigmoid();
339
340        // c' = f * c + i * g
341        let c_new = f.mul_var(c).add_var(&i.mul_var(&g));
342
343        // h' = o * tanh(c')
344        let h_new = o.mul_var(&c_new.tanh());
345
346        (h_new, c_new)
347    }
348}
349
350impl Module for LSTMCell {
351    fn forward(&self, input: &Variable) -> Variable {
352        let batch_size = input.shape()[0];
353        let h = Variable::new(
354            zeros(&[batch_size, self.hidden_size]),
355            input.requires_grad(),
356        );
357        let c = Variable::new(
358            zeros(&[batch_size, self.hidden_size]),
359            input.requires_grad(),
360        );
361        let (h_new, _) = self.forward_step(input, &(h, c));
362        h_new
363    }
364
365    fn parameters(&self) -> Vec<Parameter> {
366        vec![
367            self.weight_ih.clone(),
368            self.weight_hh.clone(),
369            self.bias_ih.clone(),
370            self.bias_hh.clone(),
371        ]
372    }
373
374    fn named_parameters(&self) -> HashMap<String, Parameter> {
375        let mut params = HashMap::new();
376        params.insert("weight_ih".to_string(), self.weight_ih.clone());
377        params.insert("weight_hh".to_string(), self.weight_hh.clone());
378        params.insert("bias_ih".to_string(), self.bias_ih.clone());
379        params.insert("bias_hh".to_string(), self.bias_hh.clone());
380        params
381    }
382
383    fn name(&self) -> &'static str {
384        "LSTMCell"
385    }
386}
387
// =============================================================================
// LSTM
// =============================================================================

/// Multi-layer LSTM.
pub struct LSTM {
    /// One `LSTMCell` per stacked layer; layer 0 consumes the raw input.
    cells: Vec<LSTMCell>,
    /// Input feature size accepted by layer 0.
    input_size: usize,
    /// Hidden/cell state size shared by every layer.
    hidden_size: usize,
    /// Number of stacked layers.
    num_layers: usize,
    /// If true, input is (batch, seq, features); otherwise (seq, batch, features).
    batch_first: bool,
}
405
406impl LSTM {
407    /// Creates a new multi-layer LSTM.
408    pub fn new(input_size: usize, hidden_size: usize, num_layers: usize) -> Self {
409        Self::with_options(input_size, hidden_size, num_layers, true)
410    }
411
412    /// Creates an LSTM with all options.
413    pub fn with_options(
414        input_size: usize,
415        hidden_size: usize,
416        num_layers: usize,
417        batch_first: bool,
418    ) -> Self {
419        let mut cells = Vec::with_capacity(num_layers);
420        cells.push(LSTMCell::new(input_size, hidden_size));
421        for _ in 1..num_layers {
422            cells.push(LSTMCell::new(hidden_size, hidden_size));
423        }
424
425        Self {
426            cells,
427            input_size,
428            hidden_size,
429            num_layers,
430            batch_first,
431        }
432    }
433
434    /// Returns the expected input size.
435    pub fn input_size(&self) -> usize {
436        self.input_size
437    }
438
439    /// Returns the hidden state size.
440    pub fn hidden_size(&self) -> usize {
441        self.hidden_size
442    }
443
444    /// Returns the number of layers.
445    pub fn num_layers(&self) -> usize {
446        self.num_layers
447    }
448}
449
450impl Module for LSTM {
451    fn forward(&self, input: &Variable) -> Variable {
452        let shape = input.shape();
453        let (batch_size, seq_len, input_features) = if self.batch_first {
454            (shape[0], shape[1], shape[2])
455        } else {
456            (shape[1], shape[0], shape[2])
457        };
458
459        let lstm_input_device = input.data().device();
460        #[cfg(feature = "cuda")]
461        let lstm_on_gpu = lstm_input_device.is_gpu();
462        #[cfg(not(feature = "cuda"))]
463        let lstm_on_gpu = false;
464
465        let mut states: Vec<(Variable, Variable)> = (0..self.num_layers)
466            .map(|_| {
467                let make_h = || {
468                    let h_cpu = zeros(&[batch_size, self.hidden_size]);
469                    let h_tensor = if lstm_on_gpu {
470                        h_cpu
471                            .to_device(lstm_input_device)
472                            .expect("LSTM: failed to move hidden state to GPU")
473                    } else {
474                        h_cpu
475                    };
476                    Variable::new(h_tensor, input.requires_grad())
477                };
478                (make_h(), make_h())
479            })
480            .collect();
481
482        // Pre-compute input-to-hidden projection for layer 0 across ALL timesteps
483        // input: [batch, seq, features] -> reshaped to [batch*seq, features]
484        // ih_all: [batch*seq, 4*hidden] = input_2d @ W_ih^T + bias_ih
485        // Note: matmul auto-dispatches to cuBLAS GEMM when tensors are on GPU
486        let cell0 = &self.cells[0];
487        let input_2d = input.reshape(&[batch_size * seq_len, input_features]);
488        let w_ih_t = cell0.weight_ih.variable().transpose(0, 1);
489        let ih_all = input_2d.matmul(&w_ih_t).add_var(&cell0.bias_ih.variable());
490        // ih_all_3d: [batch, seq, 4*hidden]
491        let ih_all_3d = ih_all.reshape(&[batch_size, seq_len, 4 * self.hidden_size]);
492
493        // Hoist weight transpose + bias out of the per-timestep loop
494        let w_hh_t_0 = cell0.weight_hh.variable().transpose(0, 1);
495        let bias_hh_0 = cell0.bias_hh.variable();
496
497        let mut outputs = Vec::with_capacity(seq_len);
498
499        // Check if we're on GPU for fused gate kernel path
500        #[cfg(feature = "cuda")]
501        let on_gpu = input.data().device().is_gpu();
502        #[cfg(not(feature = "cuda"))]
503        let on_gpu = false;
504
505        for t in 0..seq_len {
506            // Layer 0: use pre-computed ih projection + hoisted weight transpose
507            let ih_t = ih_all_3d.select(1, t);
508            let (h, c) = &states[0];
509
510            // h @ W_hh^T + bias_hh (cuBLAS on GPU, matrixmultiply on CPU)
511            let hh = h.matmul(&w_hh_t_0).add_var(&bias_hh_0);
512
513            // Combined gates = ih + hh
514            let gates = ih_t.add_var(&hh);
515
516            if on_gpu {
517                // GPU path: fused LSTM gate kernel (1 launch vs ~14 separate ops)
518                // gates [batch, 4*hidden], c [batch, hidden] → h_new, c_new [batch, hidden]
519                #[cfg(feature = "cuda")]
520                {
521                    let hs = self.hidden_size;
522                    let gates_data = gates.data();
523                    let c_data = c.data();
524
525                    if let Some((h_tensor, c_tensor)) = gates_data.lstm_gates_fused(&c_data, hs) {
526                        // Save forward state for backward
527                        let saved_gates = gates_data.clone();
528                        let saved_c_prev = c_data.clone();
529                        let saved_c_new = c_tensor.clone();
530
531                        // Create proper backward that calls LSTM backward kernel
532                        let backward_fn = axonml_autograd::LstmGatesBackward::new(
533                            gates.grad_fn().cloned(),
534                            c.grad_fn().cloned(),
535                            saved_gates,
536                            saved_c_prev,
537                            saved_c_new,
538                            hs,
539                        );
540                        let grad_fn = axonml_autograd::GradFn::new(backward_fn);
541
542                        let fused_requires_grad = gates.requires_grad() || c.requires_grad();
543                        let h_new = Variable::from_operation(
544                            h_tensor,
545                            grad_fn.clone(),
546                            fused_requires_grad,
547                        );
548                        let c_new =
549                            Variable::from_operation(c_tensor, grad_fn, fused_requires_grad);
550                        states[0] = (h_new, c_new);
551                    }
552                }
553            } else {
554                // CPU path: individual ops (each autograd-tracked)
555                let hs = self.hidden_size;
556                let i_gate = gates.narrow(1, 0, hs).sigmoid();
557                let f_gate = gates.narrow(1, hs, hs).sigmoid();
558                let g_gate = gates.narrow(1, 2 * hs, hs).tanh();
559                let o_gate = gates.narrow(1, 3 * hs, hs).sigmoid();
560                let c_new = f_gate.mul_var(c).add_var(&i_gate.mul_var(&g_gate));
561                let h_new = o_gate.mul_var(&c_new.tanh());
562                states[0] = (h_new, c_new);
563            }
564
565            // Subsequent layers use the regular cell forward_step
566            for l in 1..self.num_layers {
567                let layer_input = states[l - 1].0.clone();
568                states[l] = self.cells[l].forward_step(&layer_input, &states[l]);
569            }
570
571            outputs.push(states[self.num_layers - 1].0.clone());
572        }
573
574        // Stack outputs along the time dimension
575        let time_dim = usize::from(self.batch_first);
576        let unsqueezed: Vec<Variable> = outputs.iter().map(|o| o.unsqueeze(time_dim)).collect();
577        let refs: Vec<&Variable> = unsqueezed.iter().collect();
578        Variable::cat(&refs, time_dim)
579    }
580
581    fn parameters(&self) -> Vec<Parameter> {
582        self.cells.iter().flat_map(|c| c.parameters()).collect()
583    }
584
585    fn named_parameters(&self) -> HashMap<String, Parameter> {
586        let mut params = HashMap::new();
587        if self.cells.len() == 1 {
588            // Single layer: expose directly without cell index prefix
589            for (n, p) in self.cells[0].named_parameters() {
590                params.insert(n, p);
591            }
592        } else {
593            for (i, cell) in self.cells.iter().enumerate() {
594                for (n, p) in cell.named_parameters() {
595                    params.insert(format!("cells.{i}.{n}"), p);
596                }
597            }
598        }
599        params
600    }
601
602    fn name(&self) -> &'static str {
603        "LSTM"
604    }
605}
606
// =============================================================================
// GRUCell and GRU
// =============================================================================

/// A single GRU cell.
///
/// h' = (1 - z) * n + z * h
/// where:
///   r = sigmoid(W_ir * x + b_ir + W_hr * h + b_hr)  (reset gate)
///   z = sigmoid(W_iz * x + b_iz + W_hz * h + b_hz)  (update gate)
///   n = tanh(W_in * x + b_in + r * (W_hn * h + b_hn))  (new gate)
pub struct GRUCell {
    /// Input-hidden weights for all gates, packed (reset, update, new);
    /// the gate dimension is `3 * hidden_size`.
    pub weight_ih: Parameter,
    /// Hidden-hidden weights for all gates, packed (reset, update, new).
    pub weight_hh: Parameter,
    /// Input-hidden bias for all gates.
    pub bias_ih: Parameter,
    /// Hidden-hidden bias for all gates.
    pub bias_hh: Parameter,
    /// Number of features expected in each input vector.
    input_size: usize,
    /// Number of features in the hidden state.
    hidden_size: usize,
}
632
633impl GRUCell {
634    /// Creates a new GRU cell.
635    pub fn new(input_size: usize, hidden_size: usize) -> Self {
636        Self {
637            weight_ih: Parameter::named(
638                "weight_ih",
639                xavier_uniform(input_size, 3 * hidden_size),
640                true,
641            ),
642            weight_hh: Parameter::named(
643                "weight_hh",
644                xavier_uniform(hidden_size, 3 * hidden_size),
645                true,
646            ),
647            bias_ih: Parameter::named("bias_ih", zeros(&[3 * hidden_size]), true),
648            bias_hh: Parameter::named("bias_hh", zeros(&[3 * hidden_size]), true),
649            input_size,
650            hidden_size,
651        }
652    }
653
654    /// Returns the expected input size.
655    pub fn input_size(&self) -> usize {
656        self.input_size
657    }
658
659    /// Returns the hidden state size.
660    pub fn hidden_size(&self) -> usize {
661        self.hidden_size
662    }
663}
664
665impl GRUCell {
666    /// Forward pass for a single time step with explicit hidden state.
667    ///
668    /// GRU equations:
669    /// r_t = sigmoid(W_ir @ x_t + b_ir + W_hr @ h_{t-1} + b_hr)
670    /// z_t = sigmoid(W_iz @ x_t + b_iz + W_hz @ h_{t-1} + b_hz)
671    /// n_t = tanh(W_in @ x_t + b_in + r_t * (W_hn @ h_{t-1} + b_hn))
672    /// h_t = (1 - z_t) * n_t + z_t * h_{t-1}
673    ///
674    /// All computations use Variable operations for proper gradient flow.
675    pub fn forward_step(&self, input: &Variable, hidden: &Variable) -> Variable {
676        let _batch_size = input.shape()[0];
677        let hidden_size = self.hidden_size;
678
679        // Get weight matrices
680        let weight_ih = self.weight_ih.variable();
681        let weight_hh = self.weight_hh.variable();
682        let bias_ih = self.bias_ih.variable();
683        let bias_hh = self.bias_hh.variable();
684
685        // Compute input transformation: x @ W_ih^T + b_ih
686        // Shape: [batch, 3*hidden_size]
687        let weight_ih_t = weight_ih.transpose(0, 1);
688        let ih = input.matmul(&weight_ih_t).add_var(&bias_ih);
689
690        // Compute hidden transformation: h @ W_hh^T + b_hh
691        // Shape: [batch, 3*hidden_size]
692        let weight_hh_t = weight_hh.transpose(0, 1);
693        let hh = hidden.matmul(&weight_hh_t).add_var(&bias_hh);
694
695        // Use narrow to split into gates (preserves gradient flow)
696        // Each gate slice: [batch, hidden_size]
697        let ih_r = ih.narrow(1, 0, hidden_size);
698        let ih_z = ih.narrow(1, hidden_size, hidden_size);
699        let ih_n = ih.narrow(1, 2 * hidden_size, hidden_size);
700
701        let hh_r = hh.narrow(1, 0, hidden_size);
702        let hh_z = hh.narrow(1, hidden_size, hidden_size);
703        let hh_n = hh.narrow(1, 2 * hidden_size, hidden_size);
704
705        // Compute gates using Variable operations for gradient flow
706        // r = sigmoid(ih_r + hh_r)
707        let r = ih_r.add_var(&hh_r).sigmoid();
708
709        // z = sigmoid(ih_z + hh_z)
710        let z = ih_z.add_var(&hh_z).sigmoid();
711
712        // n = tanh(ih_n + r * hh_n)
713        let n = ih_n.add_var(&r.mul_var(&hh_n)).tanh();
714
715        // h_new = (1 - z) * n + z * h_prev
716        // Rewritten as: n + z * (h_prev - n)  to avoid allocating a ones tensor
717        let h_minus_n = hidden.sub_var(&n);
718        n.add_var(&z.mul_var(&h_minus_n))
719    }
720}
721
722impl Module for GRUCell {
723    fn forward(&self, input: &Variable) -> Variable {
724        let batch_size = input.shape()[0];
725
726        // Initialize hidden state to zeros
727        let hidden = Variable::new(
728            zeros(&[batch_size, self.hidden_size]),
729            input.requires_grad(),
730        );
731
732        self.forward_step(input, &hidden)
733    }
734
735    fn parameters(&self) -> Vec<Parameter> {
736        vec![
737            self.weight_ih.clone(),
738            self.weight_hh.clone(),
739            self.bias_ih.clone(),
740            self.bias_hh.clone(),
741        ]
742    }
743
744    fn named_parameters(&self) -> HashMap<String, Parameter> {
745        let mut params = HashMap::new();
746        params.insert("weight_ih".to_string(), self.weight_ih.clone());
747        params.insert("weight_hh".to_string(), self.weight_hh.clone());
748        params.insert("bias_ih".to_string(), self.bias_ih.clone());
749        params.insert("bias_hh".to_string(), self.bias_hh.clone());
750        params
751    }
752
753    fn name(&self) -> &'static str {
754        "GRUCell"
755    }
756}
757
/// Multi-layer GRU.
pub struct GRU {
    /// One `GRUCell` per stacked layer; layer 0 consumes the raw input.
    cells: Vec<GRUCell>,
    /// Hidden state size shared by every layer.
    hidden_size: usize,
    /// Number of stacked layers.
    num_layers: usize,
    /// If true, input is (batch, seq, features), else (seq, batch, features).
    batch_first: bool,
}
769
770impl GRU {
771    /// Creates a new multi-layer GRU.
772    pub fn new(input_size: usize, hidden_size: usize, num_layers: usize) -> Self {
773        let mut cells = Vec::with_capacity(num_layers);
774        cells.push(GRUCell::new(input_size, hidden_size));
775        for _ in 1..num_layers {
776            cells.push(GRUCell::new(hidden_size, hidden_size));
777        }
778        Self {
779            cells,
780            hidden_size,
781            num_layers,
782            batch_first: true,
783        }
784    }
785
786    /// Returns the hidden state size.
787    pub fn hidden_size(&self) -> usize {
788        self.hidden_size
789    }
790
791    /// Returns the number of layers.
792    pub fn num_layers(&self) -> usize {
793        self.num_layers
794    }
795}
796
797impl Module for GRU {
798    fn forward(&self, input: &Variable) -> Variable {
799        let shape = input.shape();
800        let (batch_size, seq_len, input_features) = if self.batch_first {
801            (shape[0], shape[1], shape[2])
802        } else {
803            (shape[1], shape[0], shape[2])
804        };
805
806        // Check if we're on GPU for fused gate kernel path
807        #[cfg(feature = "cuda")]
808        let on_gpu = input.data().device().is_gpu();
809        #[cfg(not(feature = "cuda"))]
810        let on_gpu = false;
811
812        let input_device = input.data().device();
813
814        // Initialize hidden states for all layers as Variables (with gradients)
815        // Move to the same device as input so GPU fused kernels receive GPU tensors.
816        let mut hidden_states: Vec<Variable> = (0..self.num_layers)
817            .map(|_| {
818                let h_cpu = zeros(&[batch_size, self.hidden_size]);
819                let h_tensor = if on_gpu {
820                    h_cpu
821                        .to_device(input_device)
822                        .expect("GRU: failed to move hidden state to GPU")
823                } else {
824                    h_cpu
825                };
826                Variable::new(h_tensor, input.requires_grad())
827            })
828            .collect();
829
830        // Pre-compute input-to-hidden projection for layer 0 across ALL timesteps
831        // One big matmul instead of seq_len small ones
832        let cell0 = &self.cells[0];
833        let input_2d = input.reshape(&[batch_size * seq_len, input_features]);
834        let w_ih_t = cell0.weight_ih.variable().transpose(0, 1);
835        let ih_all = input_2d.matmul(&w_ih_t).add_var(&cell0.bias_ih.variable());
836        let ih_all_3d = ih_all.reshape(&[batch_size, seq_len, 3 * self.hidden_size]);
837
838        // Hoist weight transpose + bias out of the per-timestep loop
839        let w_hh_t_0 = cell0.weight_hh.variable().transpose(0, 1);
840        let bias_hh_0 = cell0.bias_hh.variable();
841
842        let mut output_vars: Vec<Variable> = Vec::with_capacity(seq_len);
843
844        for t in 0..seq_len {
845            // Layer 0: use pre-computed ih projection + hoisted weight transpose
846            let ih_t = ih_all_3d.select(1, t);
847            let hidden = &hidden_states[0];
848            let hs = self.hidden_size;
849
850            let hh = hidden.matmul(&w_hh_t_0).add_var(&bias_hh_0);
851
852            if on_gpu {
853                // GPU path: fused GRU gate kernel (1 launch vs ~12 separate ops)
854                // ih_t [batch, 3*hidden], hh [batch, 3*hidden], hidden [batch, hidden] → h_new [batch, hidden]
855                #[cfg(feature = "cuda")]
856                {
857                    let ih_data = ih_t.data();
858                    let hh_data = hh.data();
859                    let h_data = hidden.data();
860
861                    if let Some(h_tensor) = ih_data.gru_gates_fused(&hh_data, &h_data, hs) {
862                        // Save forward state for backward
863                        let saved_ih = ih_data.clone();
864                        let saved_hh = hh_data.clone();
865                        let saved_h_prev = h_data.clone();
866
867                        // Create proper backward that calls GRU backward kernel
868                        let backward_fn = axonml_autograd::GruGatesBackward::new(
869                            ih_t.grad_fn().cloned(),
870                            hh.grad_fn().cloned(),
871                            hidden.grad_fn().cloned(),
872                            saved_ih,
873                            saved_hh,
874                            saved_h_prev,
875                            hs,
876                        );
877                        let grad_fn = axonml_autograd::GradFn::new(backward_fn);
878
879                        // Use requires_grad=true if ANY input to the fused op
880                        // requires grad — the GRU parameters (w_ih, w_hh, bias)
881                        // always require grad during training, so ih_t and hh
882                        // will have requires_grad=true even when the raw input
883                        // Variable does not.
884                        let fused_requires_grad =
885                            ih_t.requires_grad() || hh.requires_grad() || hidden.requires_grad();
886                        let h_new =
887                            Variable::from_operation(h_tensor, grad_fn, fused_requires_grad);
888                        hidden_states[0] = h_new;
889                    }
890                }
891            } else {
892                // CPU path: individual ops (each autograd-tracked)
893                let ih_r = ih_t.narrow(1, 0, hs);
894                let ih_z = ih_t.narrow(1, hs, hs);
895                let ih_n = ih_t.narrow(1, 2 * hs, hs);
896                let hh_r = hh.narrow(1, 0, hs);
897                let hh_z = hh.narrow(1, hs, hs);
898                let hh_n = hh.narrow(1, 2 * hs, hs);
899
900                let r = ih_r.add_var(&hh_r).sigmoid();
901                let z = ih_z.add_var(&hh_z).sigmoid();
902                let n = ih_n.add_var(&r.mul_var(&hh_n)).tanh();
903                let h_minus_n = hidden.sub_var(&n);
904                let h_new = n.add_var(&z.mul_var(&h_minus_n));
905                hidden_states[0] = h_new;
906            }
907
908            // Subsequent layers use the regular cell forward_step
909            let mut layer_output = hidden_states[0].clone();
910            for l in 1..self.num_layers {
911                let new_hidden = self.cells[l].forward_step(&layer_output, &hidden_states[l]);
912                hidden_states[l] = new_hidden.clone();
913                layer_output = new_hidden;
914            }
915
916            output_vars.push(layer_output);
917        }
918
919        // Stack outputs along the time dimension
920        self.stack_outputs(&output_vars, batch_size, seq_len)
921    }
922
923    fn parameters(&self) -> Vec<Parameter> {
924        self.cells.iter().flat_map(|c| c.parameters()).collect()
925    }
926
927    fn named_parameters(&self) -> HashMap<String, Parameter> {
928        let mut params = HashMap::new();
929        if self.cells.len() == 1 {
930            for (n, p) in self.cells[0].named_parameters() {
931                params.insert(n, p);
932            }
933        } else {
934            for (i, cell) in self.cells.iter().enumerate() {
935                for (n, p) in cell.named_parameters() {
936                    params.insert(format!("cells.{i}.{n}"), p);
937                }
938            }
939        }
940        params
941    }
942
    /// Static module name reported for this layer type.
    fn name(&self) -> &'static str {
        "GRU"
    }
946}
947
948impl GRU {
949    /// Forward pass that returns the mean of all hidden states.
950    /// This is equivalent to processing then mean pooling, but with proper gradient flow.
951    pub fn forward_mean(&self, input: &Variable) -> Variable {
952        let shape = input.shape();
953        let (batch_size, seq_len, input_features) = if self.batch_first {
954            (shape[0], shape[1], shape[2])
955        } else {
956            (shape[1], shape[0], shape[2])
957        };
958
959        let mut hidden_states: Vec<Variable> = (0..self.num_layers)
960            .map(|_| {
961                Variable::new(
962                    zeros(&[batch_size, self.hidden_size]),
963                    input.requires_grad(),
964                )
965            })
966            .collect();
967
968        // Pre-compute input-to-hidden projection for layer 0 across ALL timesteps
969        let cell0 = &self.cells[0];
970        let input_2d = input.reshape(&[batch_size * seq_len, input_features]);
971        let w_ih_t = cell0.weight_ih.variable().transpose(0, 1);
972        let ih_all = input_2d.matmul(&w_ih_t).add_var(&cell0.bias_ih.variable());
973        let ih_all_3d = ih_all.reshape(&[batch_size, seq_len, 3 * self.hidden_size]);
974
975        // Hoist weight transpose + bias out of per-timestep loop
976        let w_hh_t_0 = cell0.weight_hh.variable().transpose(0, 1);
977        let bias_hh_0 = cell0.bias_hh.variable();
978
979        let mut output_sum: Option<Variable> = None;
980        let hs = self.hidden_size;
981
982        for t in 0..seq_len {
983            // Layer 0: use pre-computed ih projection + hoisted weight transpose
984            let ih_t = ih_all_3d.select(1, t);
985            let hidden = &hidden_states[0];
986            let hh = hidden.matmul(&w_hh_t_0).add_var(&bias_hh_0);
987
988            let ih_r = ih_t.narrow(1, 0, hs);
989            let ih_z = ih_t.narrow(1, hs, hs);
990            let ih_n = ih_t.narrow(1, 2 * hs, hs);
991            let hh_r = hh.narrow(1, 0, hs);
992            let hh_z = hh.narrow(1, hs, hs);
993            let hh_n = hh.narrow(1, 2 * hs, hs);
994
995            let r = ih_r.add_var(&hh_r).sigmoid();
996            let z = ih_z.add_var(&hh_z).sigmoid();
997            let n = ih_n.add_var(&r.mul_var(&hh_n)).tanh();
998            let h_minus_n = hidden.sub_var(&n);
999            let h_new = n.add_var(&z.mul_var(&h_minus_n));
1000            hidden_states[0] = h_new.clone();
1001
1002            // Subsequent layers
1003            let mut layer_output = h_new;
1004            for l in 1..self.num_layers {
1005                let new_hidden = self.cells[l].forward_step(&layer_output, &hidden_states[l]);
1006                hidden_states[l] = new_hidden.clone();
1007                layer_output = new_hidden;
1008            }
1009
1010            output_sum = Some(match output_sum {
1011                None => layer_output,
1012                Some(acc) => acc.add_var(&layer_output),
1013            });
1014        }
1015
1016        match output_sum {
1017            Some(sum) => sum.mul_scalar(1.0 / seq_len as f32),
1018            None => Variable::new(zeros(&[batch_size, self.hidden_size]), false),
1019        }
1020    }
1021
1022    /// Forward pass that returns the last hidden state.
1023    /// Good for sequence classification with proper gradient flow.
1024    pub fn forward_last(&self, input: &Variable) -> Variable {
1025        let shape = input.shape();
1026        let (batch_size, seq_len, input_features) = if self.batch_first {
1027            (shape[0], shape[1], shape[2])
1028        } else {
1029            (shape[1], shape[0], shape[2])
1030        };
1031
1032        let mut hidden_states: Vec<Variable> = (0..self.num_layers)
1033            .map(|_| {
1034                Variable::new(
1035                    zeros(&[batch_size, self.hidden_size]),
1036                    input.requires_grad(),
1037                )
1038            })
1039            .collect();
1040
1041        // Pre-compute input-to-hidden projection for layer 0 across ALL timesteps
1042        let cell0 = &self.cells[0];
1043        let input_2d = input.reshape(&[batch_size * seq_len, input_features]);
1044        let w_ih_t = cell0.weight_ih.variable().transpose(0, 1);
1045        let ih_all = input_2d.matmul(&w_ih_t).add_var(&cell0.bias_ih.variable());
1046        let ih_all_3d = ih_all.reshape(&[batch_size, seq_len, 3 * self.hidden_size]);
1047
1048        // Hoist weight transpose + bias out of per-timestep loop
1049        let w_hh_t_0 = cell0.weight_hh.variable().transpose(0, 1);
1050        let bias_hh_0 = cell0.bias_hh.variable();
1051        let hs = self.hidden_size;
1052
1053        for t in 0..seq_len {
1054            // Layer 0: use pre-computed ih projection + hoisted weight transpose
1055            let ih_t = ih_all_3d.select(1, t);
1056            let hidden = &hidden_states[0];
1057            let hh = hidden.matmul(&w_hh_t_0).add_var(&bias_hh_0);
1058
1059            let ih_r = ih_t.narrow(1, 0, hs);
1060            let ih_z = ih_t.narrow(1, hs, hs);
1061            let ih_n = ih_t.narrow(1, 2 * hs, hs);
1062            let hh_r = hh.narrow(1, 0, hs);
1063            let hh_z = hh.narrow(1, hs, hs);
1064            let hh_n = hh.narrow(1, 2 * hs, hs);
1065
1066            let r = ih_r.add_var(&hh_r).sigmoid();
1067            let z = ih_z.add_var(&hh_z).sigmoid();
1068            let n = ih_n.add_var(&r.mul_var(&hh_n)).tanh();
1069            let h_minus_n = hidden.sub_var(&n);
1070            let h_new = n.add_var(&z.mul_var(&h_minus_n));
1071            hidden_states[0] = h_new.clone();
1072
1073            // Subsequent layers
1074            let mut layer_input = h_new;
1075
1076            for (layer_idx, cell) in self.cells.iter().enumerate().skip(1) {
1077                let new_hidden = cell.forward_step(&layer_input, &hidden_states[layer_idx]);
1078                hidden_states[layer_idx] = new_hidden.clone();
1079                layer_input = new_hidden;
1080            }
1081        }
1082
1083        // Return last hidden state from last layer
1084        hidden_states
1085            .pop()
1086            .unwrap_or_else(|| Variable::new(zeros(&[batch_size, self.hidden_size]), false))
1087    }
1088
1089    /// Stack output Variables into a single [batch, seq, hidden] tensor.
1090    /// Note: This creates a new tensor without gradient connections to individual timesteps.
1091    /// For gradient flow, use forward_mean() or forward_last() instead.
1092    fn stack_outputs(&self, outputs: &[Variable], batch_size: usize, _seq_len: usize) -> Variable {
1093        if outputs.is_empty() {
1094            return Variable::new(zeros(&[batch_size, 0, self.hidden_size]), false);
1095        }
1096
1097        // Unsqueeze each (batch, hidden) → (batch, 1, hidden), then cat along dim=1
1098        let unsqueezed: Vec<Variable> = outputs.iter().map(|o| o.unsqueeze(1)).collect();
1099        let refs: Vec<&Variable> = unsqueezed.iter().collect();
1100        Variable::cat(&refs, 1)
1101    }
1102}
1103
1104// =============================================================================
1105// Tests
1106// =============================================================================
1107
#[cfg(test)]
mod tests {
    use super::*;
    use axonml_tensor::Tensor;

    // Single RNN cell step: [2, 10] input + [2, 20] hidden -> [2, 20] output.
    #[test]
    fn test_rnn_cell() {
        let cell = RNNCell::new(10, 20);
        let input = Variable::new(Tensor::from_vec(vec![1.0; 20], &[2, 10]).unwrap(), false);
        let hidden = Variable::new(Tensor::from_vec(vec![0.0; 40], &[2, 20]).unwrap(), false);
        let output = cell.forward_step(&input, &hidden);
        assert_eq!(output.shape(), vec![2, 20]);
    }

    // Two-layer RNN over a [2, 5, 10] batch-first sequence keeps [B, T, H] shape.
    #[test]
    fn test_rnn() {
        let rnn = RNN::new(10, 20, 2);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 100], &[2, 5, 10]).unwrap(),
            false,
        );
        let output = rnn.forward(&input);
        assert_eq!(output.shape(), vec![2, 5, 20]);
    }

    // Single-layer LSTM shape check on a [2, 5, 10] input.
    #[test]
    fn test_lstm() {
        let lstm = LSTM::new(10, 20, 1);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 100], &[2, 5, 10]).unwrap(),
            false,
        );
        let output = lstm.forward(&input);
        assert_eq!(output.shape(), vec![2, 5, 20]);
    }

    // End-to-end backward through GRU: at least one parameter must receive a
    // non-zero gradient. The println!s are diagnostic output for failures.
    #[test]
    fn test_gru_gradients_reach_parameters() {
        let gru = GRU::new(4, 8, 1);
        let input = Variable::new(
            Tensor::from_vec(vec![0.5f32; 2 * 3 * 4], &[2, 3, 4]).unwrap(),
            true,
        );
        let output = gru.forward(&input);
        println!(
            "Output shape: {:?}, requires_grad: {}",
            output.shape(),
            output.requires_grad()
        );
        let loss = output.sum();
        println!(
            "Loss: {:?}, requires_grad: {}",
            loss.data().to_vec(),
            loss.requires_grad()
        );
        loss.backward();

        // Check input gradient
        println!(
            "Input grad: {:?}",
            input
                .grad()
                .map(|g| g.to_vec().iter().map(|x| x.abs()).sum::<f32>())
        );

        let params = gru.parameters();
        println!("Number of parameters: {}", params.len());
        let mut has_grad = false;
        for (i, p) in params.iter().enumerate() {
            let grad = p.grad();
            match grad {
                Some(g) => {
                    let gv = g.to_vec();
                    // Sum of absolute gradient entries; > 0 means this
                    // parameter received a real gradient signal.
                    let sum_abs: f32 = gv.iter().map(|x| x.abs()).sum();
                    println!(
                        "Param {} shape {:?} requires_grad={}: grad sum_abs={:.6}",
                        i,
                        p.shape(),
                        p.requires_grad(),
                        sum_abs
                    );
                    if sum_abs > 0.0 {
                        has_grad = true;
                    }
                }
                None => {
                    println!(
                        "Param {} shape {:?} requires_grad={}: NO GRADIENT",
                        i,
                        p.shape(),
                        p.requires_grad()
                    );
                }
            }
        }
        assert!(
            has_grad,
            "At least one GRU parameter should have non-zero gradients"
        );
    }

    // =========================================================================
    // LSTM Comprehensive
    // =========================================================================

    // LSTMCell takes an (h, c) state pair and returns an (h, c) pair of the
    // same shapes.
    #[test]
    fn test_lstm_cell_forward_step() {
        let cell = LSTMCell::new(8, 16);
        let input = Variable::new(Tensor::from_vec(vec![1.0; 2 * 8], &[2, 8]).unwrap(), false);
        let hidden = Variable::new(
            Tensor::from_vec(vec![0.0; 2 * 16], &[2, 16]).unwrap(),
            false,
        );
        let cell_state = Variable::new(
            Tensor::from_vec(vec![0.0; 2 * 16], &[2, 16]).unwrap(),
            false,
        );
        let hx = (hidden, cell_state);
        let (h, c) = cell.forward_step(&input, &hx);
        assert_eq!(h.shape(), vec![2, 16]);
        assert_eq!(c.shape(), vec![2, 16]);
    }

    // Stacked (3-layer) LSTM reports its config and preserves [B, T, H].
    #[test]
    fn test_lstm_multi_layer() {
        let lstm = LSTM::new(8, 16, 3); // 3 layers
        assert_eq!(lstm.num_layers(), 3);
        assert_eq!(lstm.hidden_size(), 16);

        let input = Variable::new(
            Tensor::from_vec(vec![0.5; 2 * 5 * 8], &[2, 5, 8]).unwrap(),
            false,
        );
        let output = lstm.forward(&input);
        assert_eq!(output.shape(), vec![2, 5, 16]);
    }

    #[test]
    fn test_lstm_forward_last() {
        let lstm = LSTM::new(8, 16, 1);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 2 * 10 * 8], &[2, 10, 8]).unwrap(),
            false,
        );
        // forward_last should return only the last time step
        // The LSTM module may not have forward_last, but forward returns [B, T, H]
        let output = lstm.forward(&input);
        assert_eq!(output.shape(), vec![2, 10, 16]);

        // Last timestep extraction
        let out_vec = output.data().to_vec();
        let last_t0 = &out_vec[9 * 16..10 * 16]; // batch 0, time 9
        assert!(
            last_t0.iter().all(|v| v.is_finite()),
            "Last output should be finite"
        );
    }

    // Gradients must reach both the input and the LSTM parameters.
    #[test]
    fn test_lstm_gradient_flow() {
        let lstm = LSTM::new(4, 8, 1);
        let input = Variable::new(
            Tensor::from_vec(vec![0.5; 1 * 3 * 4], &[1, 3, 4]).unwrap(),
            true,
        );
        let output = lstm.forward(&input);
        let loss = output.sum();
        loss.backward();

        let input_grad = input
            .grad()
            .expect("Input should have gradient through LSTM");
        assert_eq!(input_grad.shape(), &[1, 3, 4]);
        assert!(
            input_grad.to_vec().iter().any(|g| g.abs() > 1e-10),
            "LSTM should propagate gradients to input"
        );

        // Parameters should also have gradients
        let params = lstm.parameters();
        let grads_exist = params.iter().any(|p| {
            p.grad()
                .map(|g| g.to_vec().iter().any(|v| v.abs() > 0.0))
                .unwrap_or(false)
        });
        assert!(grads_exist, "LSTM parameters should have gradients");
    }

    // The same LSTM instance must handle both short and long sequences.
    #[test]
    fn test_lstm_different_sequence_lengths() {
        let lstm = LSTM::new(4, 8, 1);

        // Short sequence
        let short = Variable::new(
            Tensor::from_vec(vec![1.0; 1 * 2 * 4], &[1, 2, 4]).unwrap(),
            false,
        );
        let out_short = lstm.forward(&short);
        assert_eq!(out_short.shape(), vec![1, 2, 8]);

        // Long sequence
        let long = Variable::new(
            Tensor::from_vec(vec![1.0; 1 * 20 * 4], &[1, 20, 4]).unwrap(),
            false,
        );
        let out_long = lstm.forward(&long);
        assert_eq!(out_long.shape(), vec![1, 20, 8]);
    }

    #[test]
    fn test_lstm_parameters_count() {
        // LSTM has 4 gates (i, f, g, o), each with input and hidden weights + biases
        // Per layer: 4 * (input_size * hidden_size + hidden_size * hidden_size + 2 * hidden_size)
        let lstm = LSTM::new(10, 20, 1);
        let n = lstm.parameters().iter().map(|p| p.numel()).sum::<usize>();
        // Expected: 4 * (10*20 + 20*20 + 20 + 20) = 4 * (200 + 400 + 40) = 2560
        // Only the presence of parameters is asserted here, not the exact count.
        assert!(n > 0, "LSTM should have parameters");
    }

    // =========================================================================
    // GRU Comprehensive
    // =========================================================================

    // GRUCell reports its sizes and maps [2, 8] + [2, 16] -> [2, 16].
    #[test]
    fn test_gru_cell_forward_step() {
        let cell = GRUCell::new(8, 16);
        assert_eq!(cell.input_size(), 8);
        assert_eq!(cell.hidden_size(), 16);

        let input = Variable::new(Tensor::from_vec(vec![1.0; 2 * 8], &[2, 8]).unwrap(), false);
        let hidden = Variable::new(
            Tensor::from_vec(vec![0.0; 2 * 16], &[2, 16]).unwrap(),
            false,
        );
        let output = cell.forward_step(&input, &hidden);
        assert_eq!(output.shape(), vec![2, 16]);
    }

    // Stacked (2-layer) GRU reports its config and preserves [B, T, H].
    #[test]
    fn test_gru_multi_layer() {
        let gru = GRU::new(8, 16, 2);
        assert_eq!(gru.num_layers(), 2);
        assert_eq!(gru.hidden_size(), 16);

        let input = Variable::new(
            Tensor::from_vec(vec![0.5; 2 * 5 * 8], &[2, 5, 8]).unwrap(),
            false,
        );
        let output = gru.forward(&input);
        assert_eq!(output.shape(), vec![2, 5, 16]);
    }

    #[test]
    fn test_gru_forward_mean() {
        let gru = GRU::new(4, 8, 1);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 2 * 5 * 4], &[2, 5, 4]).unwrap(),
            false,
        );
        let mean_out = gru.forward_mean(&input);
        // forward_mean averages over time: [B, T, H] → [B, H]
        assert_eq!(mean_out.shape(), vec![2, 8]);
    }

    #[test]
    fn test_gru_forward_last() {
        let gru = GRU::new(4, 8, 1);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 2 * 5 * 4], &[2, 5, 4]).unwrap(),
            false,
        );
        let last_out = gru.forward_last(&input);
        // forward_last returns only last timestep: [B, T, H] → [B, H]
        assert_eq!(last_out.shape(), vec![2, 8]);
    }

    // Gradient from the summed output must reach the input tensor.
    #[test]
    fn test_gru_gradient_flow_to_input() {
        let gru = GRU::new(4, 8, 1);
        let input = Variable::new(
            Tensor::from_vec(vec![0.5; 1 * 3 * 4], &[1, 3, 4]).unwrap(),
            true,
        );
        let output = gru.forward(&input);
        output.sum().backward();

        let grad = input
            .grad()
            .expect("Input should have gradient through GRU");
        assert_eq!(grad.shape(), &[1, 3, 4]);
        assert!(
            grad.to_vec().iter().any(|g| g.abs() > 1e-10),
            "GRU should propagate gradients"
        );
    }

    // With a constant input, the recurrent state should still change between
    // the first and last timesteps (otherwise recurrence is broken).
    #[test]
    fn test_gru_hidden_state_evolves() {
        let gru = GRU::new(4, 8, 1);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 1 * 5 * 4], &[1, 5, 4]).unwrap(),
            false,
        );
        let output = gru.forward(&input);
        let out_vec = output.data().to_vec();

        // Hidden states at different timesteps should differ
        let t0 = &out_vec[0..8];
        let t4 = &out_vec[4 * 8..5 * 8];
        let diff: f32 = t0.iter().zip(t4.iter()).map(|(a, b)| (a - b).abs()).sum();
        assert!(
            diff > 1e-6,
            "GRU hidden state should evolve over time, diff={}",
            diff
        );
    }

    // =========================================================================
    // RNN Basic
    // =========================================================================

    // A single RNNCell step must propagate gradients back to its input.
    #[test]
    fn test_rnn_cell_gradient_flow() {
        let cell = RNNCell::new(4, 8);
        let input = Variable::new(Tensor::from_vec(vec![1.0; 1 * 4], &[1, 4]).unwrap(), true);
        let hidden = Variable::new(Tensor::from_vec(vec![0.0; 1 * 8], &[1, 8]).unwrap(), false);
        let out = cell.forward_step(&input, &hidden);
        out.sum().backward();

        let grad = input.grad().expect("RNNCell should propagate gradients");
        assert_eq!(grad.shape(), &[1, 4]);
    }

    #[test]
    fn test_rnn_multi_layer() {
        let rnn = RNN::with_options(8, 16, 3, true); // 3 layers, bias
        let input = Variable::new(
            Tensor::from_vec(vec![0.5; 2 * 5 * 8], &[2, 5, 8]).unwrap(),
            false,
        );
        let output = rnn.forward(&input);
        assert_eq!(output.shape(), vec![2, 5, 16]);
    }

    // =========================================================================
    // Numerical Stability
    // =========================================================================

    #[test]
    fn test_lstm_outputs_are_bounded() {
        // LSTM should produce bounded outputs (tanh output gate)
        let lstm = LSTM::new(4, 8, 1);
        let input = Variable::new(
            Tensor::from_vec(vec![100.0; 1 * 10 * 4], &[1, 10, 4]).unwrap(),
            false,
        );
        let output = lstm.forward(&input);
        let out_vec = output.data().to_vec();

        // All outputs should be in [-1, 1] range (tanh bounded)
        for v in &out_vec {
            assert!(v.is_finite(), "LSTM output should be finite, got {}", v);
            assert!(
                v.abs() <= 1.0 + 1e-5,
                "LSTM output should be bounded by tanh: got {}",
                v
            );
        }
    }

    // Saturating gate activations should keep GRU outputs finite even for
    // very large input magnitudes.
    #[test]
    fn test_gru_outputs_finite_with_large_input() {
        let gru = GRU::new(4, 8, 1);
        let input = Variable::new(
            Tensor::from_vec(vec![50.0; 1 * 5 * 4], &[1, 5, 4]).unwrap(),
            false,
        );
        let output = gru.forward(&input);
        assert!(
            output.data().to_vec().iter().all(|v| v.is_finite()),
            "GRU should produce finite outputs for large inputs"
        );
    }
}