Skip to main content

ferrotorch_nn/
rnn.rs

1//! Recurrent neural network modules.
2//!
3//! Implements [`LSTM`], [`GRU`], [`RNN`] (multi-layer modules) and their
4//! single-step cell counterparts [`LSTMCell`], [`GRUCell`], [`RNNCell`].
5//! Each mirrors the corresponding `torch.nn` module.
6//!
7//! Because the forward passes are composed entirely from differentiable
8//! operations (`mm`, `add`, `mul`, `sigmoid`, `tanh`, `relu` from
9//! `ferrotorch_core::grad_fns`), autograd builds the backward graph
10//! automatically — no custom backward functions are required.
11//!
12//! ## REQ status (per `.design/ferrotorch-nn/rnn.md`)
13//!
14//! | REQ | Status | Evidence |
15//! |---|---|---|
16//! | REQ-1 | SHIPPED | the `LSTM<T>` struct + private `LSTMLayerParams<T>` here; non-test consumer: re-export at `ferrotorch-nn/src/lib.rs:245` + `ferrotorch/src/lib.rs:50` |
17//! | REQ-2 | SHIPPED | the `LSTM::forward` body here with the four-gate update; non-test consumer: re-export at `lib.rs:245` + meta-crate prelude `ferrotorch/src/lib.rs:50` + `benchmarks/ferrotorch_bench.rs:178` |
18//! | REQ-3 | SHIPPED | the `GRU<T>` struct here; non-test consumer: re-export at `lib.rs:245` + `ferrotorch/src/lib.rs:50` + `benchmarks/ferrotorch_bench.rs:178` |
19//! | REQ-4 | SHIPPED | the `GRU::forward` body here with the three-gate update; non-test consumer: as REQ-3 |
20//! | REQ-5 | SHIPPED | the `RNNNonlinearity` enum here; non-test consumer: re-export at `lib.rs:245` |
21//! | REQ-6 | SHIPPED | the `RNN<T>` struct here; non-test consumer: re-export at `lib.rs:245` |
22//! | REQ-7 | SHIPPED | the `RNN::forward` body dispatching on `RNNNonlinearity`; non-test consumer: re-export at `lib.rs:245` |
23//! | REQ-8 | SHIPPED | the `LSTMCell<T>`, `GRUCell<T>`, `RNNCell<T>` structs here; non-test consumer: re-export at `lib.rs:245` |
//! | REQ-9 | SHIPPED | `init::uniform` and `init::zeros` calls in the constructors here; non-test consumer: re-export at `lib.rs:245` |
25//! | REQ-10 | SHIPPED | the `impl<T: Float> Module<T>` blocks for every public struct here; non-test consumer: re-export at `lib.rs:245` |
26//! | REQ-11 | SHIPPED | the `use ferrotorch_core::grad_fns::linalg::mm_differentiable as mm` import here and its use in every forward path; non-test consumer: re-export at `lib.rs:245` |
27//! | REQ-12 | NOT-STARTED | parity-sweep runner arm for `nn.functional.lstm_cell` not wired — blocker #1456 |
28//! | REQ-13 | NOT-STARTED | parity-sweep runner arm for `nn.functional.gru_cell` not wired — blocker #1456 |
29//! | REQ-14 | NOT-STARTED | parity-sweep runner arm for `nn.functional.rnn_relu_cell` not wired — blocker #1456 |
30
31use ferrotorch_core::grad_fns::activation::{relu, sigmoid, tanh};
32use ferrotorch_core::grad_fns::arithmetic::{add, mul, sub};
33use ferrotorch_core::grad_fns::shape::{cat, reshape};
34// Use the device-aware, autograd-tracked matmul. The `ops::linalg::mm`
35// alternative is host-only (calls `data()?` internally, which after the
36// Phase-2a `try_as_slice` migration returns Err on GPU storage). The
37// LSTM/GRU/RNN forward paths must dispatch to GPU when inputs are GPU-
38// resident — see #750 closure.
39use ferrotorch_core::grad_fns::linalg::mm_differentiable as mm;
40use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float, Tensor, TensorStorage};
41
42use crate::init;
43use crate::module::Module;
44use crate::parameter::Parameter;
45
46/// Output type for LSTM forward: `(output_sequence, (h_n, c_n))`.
47type LstmOutput<T> = (Tensor<T>, (Tensor<T>, Tensor<T>));
48
49// ---------------------------------------------------------------------------
50// Per-layer parameter set
51// ---------------------------------------------------------------------------
52
53/// Parameters for a single LSTM layer.
54#[derive(Debug, Clone)]
55struct LSTMLayerParams<T: Float> {
56    /// Weight matrix for input-to-hidden: shape [4*hidden_size, input_size].
57    weight_ih: Parameter<T>,
58    /// Weight matrix for hidden-to-hidden: shape [4*hidden_size, hidden_size].
59    weight_hh: Parameter<T>,
60    /// Bias for input-to-hidden: shape [4*hidden_size].
61    bias_ih: Parameter<T>,
62    /// Bias for hidden-to-hidden: shape [4*hidden_size].
63    bias_hh: Parameter<T>,
64}
65
66// ---------------------------------------------------------------------------
67// LSTM
68// ---------------------------------------------------------------------------
69
70/// A multi-layer Long Short-Term Memory (LSTM) RNN.
71///
72/// For each element in the input sequence, each layer computes:
73///
74/// ```text
75/// i = sigmoid(W_ii @ x + b_ii + W_hi @ h + b_hi)
76/// f = sigmoid(W_if @ x + b_if + W_hf @ h + b_hf)
77/// g = tanh(W_ig @ x + b_ig + W_hg @ h + b_hg)
78/// o = sigmoid(W_io @ x + b_io + W_ho @ h + b_ho)
79/// c' = f * c + i * g
80/// h' = o * tanh(c')
81/// ```
82///
83/// The weight matrices for all four gates are concatenated into a single
84/// `weight_ih` of shape `[4*hidden_size, input_size]` and `weight_hh` of
85/// shape `[4*hidden_size, hidden_size]`.
86///
87/// # Type parameter
88///
89/// `T` must implement [`Float`] — currently `f32` or `f64`.
90#[derive(Debug)]
91pub struct LSTM<T: Float> {
92    input_size: usize,
93    hidden_size: usize,
94    num_layers: usize,
95    layers: Vec<LSTMLayerParams<T>>,
96    training: bool,
97}
98
99impl<T: Float> LSTM<T> {
100    /// Create a new LSTM module.
101    ///
102    /// # Arguments
103    ///
104    /// * `input_size` — number of expected features in the input `x`.
105    /// * `hidden_size` — number of features in the hidden state `h`.
106    /// * `num_layers` — number of stacked LSTM layers (must be >= 1).
107    ///
108    /// # Weight initialization
109    ///
110    /// All weights are initialized from `U(-k, k)` where `k = 1/sqrt(hidden_size)`.
111    /// Biases are initialized to zero. This matches PyTorch's default.
112    pub fn new(input_size: usize, hidden_size: usize, num_layers: usize) -> FerrotorchResult<Self> {
113        if num_layers == 0 {
114            return Err(FerrotorchError::InvalidArgument {
115                message: "LSTM: num_layers must be >= 1".into(),
116            });
117        }
118        if hidden_size == 0 {
119            return Err(FerrotorchError::InvalidArgument {
120                message: "LSTM: hidden_size must be >= 1".into(),
121            });
122        }
123        if input_size == 0 {
124            return Err(FerrotorchError::InvalidArgument {
125                message: "LSTM: input_size must be >= 1".into(),
126            });
127        }
128
129        let k = 1.0 / (hidden_size as f64).sqrt();
130        let gate_size = 4 * hidden_size;
131
132        let mut layers = Vec::with_capacity(num_layers);
133
134        for layer_idx in 0..num_layers {
135            let layer_input_size = if layer_idx == 0 {
136                input_size
137            } else {
138                hidden_size
139            };
140
141            let mut weight_ih = Parameter::zeros(&[gate_size, layer_input_size])?;
142            let mut weight_hh = Parameter::zeros(&[gate_size, hidden_size])?;
143            let mut bias_ih = Parameter::zeros(&[gate_size])?;
144            let mut bias_hh = Parameter::zeros(&[gate_size])?;
145
146            init::uniform(&mut weight_ih, -k, k)?;
147            init::uniform(&mut weight_hh, -k, k)?;
148            init::zeros(&mut bias_ih)?;
149            init::zeros(&mut bias_hh)?;
150
151            layers.push(LSTMLayerParams {
152                weight_ih,
153                weight_hh,
154                bias_ih,
155                bias_hh,
156            });
157        }
158
159        Ok(Self {
160            input_size,
161            hidden_size,
162            num_layers,
163            layers,
164            training: true,
165        })
166    }
167
168    /// Forward pass with explicit hidden state.
169    ///
170    /// # Arguments
171    ///
172    /// * `input` — input tensor of shape `[batch, seq_len, input_size]`.
173    /// * `state` — optional `(h_0, c_0)` each of shape `[num_layers, batch, hidden_size]`.
174    ///   If `None`, both are initialized to zeros.
175    ///
176    /// # Returns
177    ///
178    /// A tuple `(output, (h_n, c_n))` where:
179    /// - `output` has shape `[batch, seq_len, hidden_size]` (last layer outputs).
180    /// - `h_n`, `c_n` each have shape `[num_layers, batch, hidden_size]`.
181    pub fn forward_with_state(
182        &self,
183        input: &Tensor<T>,
184        state: Option<(&Tensor<T>, &Tensor<T>)>,
185    ) -> FerrotorchResult<LstmOutput<T>> {
186        // Validate input shape: [B, seq_len, input_size]
187        if input.ndim() != 3 {
188            return Err(FerrotorchError::InvalidArgument {
189                message: format!(
190                    "LSTM: expected 3-D input [batch, seq_len, input_size], got shape {:?}",
191                    input.shape()
192                ),
193            });
194        }
195
196        let batch = input.shape()[0];
197        let seq_len = input.shape()[1];
198
199        if input.shape()[2] != self.input_size {
200            return Err(FerrotorchError::ShapeMismatch {
201                message: format!(
202                    "LSTM: input_size mismatch: expected {}, got {}",
203                    self.input_size,
204                    input.shape()[2]
205                ),
206            });
207        }
208
209        // Initialize hidden / cell states.
210        let (h_init, c_init) = match state {
211            Some((h0, c0)) => {
212                // Validate shapes: [num_layers, batch, hidden_size]
213                let expected_shape = [self.num_layers, batch, self.hidden_size];
214                if h0.shape() != expected_shape {
215                    return Err(FerrotorchError::ShapeMismatch {
216                        message: format!(
217                            "LSTM: h_0 shape mismatch: expected {:?}, got {:?}",
218                            expected_shape,
219                            h0.shape()
220                        ),
221                    });
222                }
223                if c0.shape() != expected_shape {
224                    return Err(FerrotorchError::ShapeMismatch {
225                        message: format!(
226                            "LSTM: c_0 shape mismatch: expected {:?}, got {:?}",
227                            expected_shape,
228                            c0.shape()
229                        ),
230                    });
231                }
232                (h0.clone(), c0.clone())
233            }
234            None => {
235                // Initialize on the input's device so the layer state and
236                // gate computation share the same device. The CPU-only
237                // `zeros` would land h0/c0 on Cpu and trigger a
238                // DeviceMismatch downstream when the input is GPU-resident.
239                let init_shape = [self.num_layers, batch, self.hidden_size];
240                let h0 = ferrotorch_core::zeros::<T>(&init_shape)?.to(input.device())?;
241                let c0 = ferrotorch_core::zeros::<T>(&init_shape)?.to(input.device())?;
242                (h0, c0)
243            }
244        };
245
246        // Extract per-timestep input slices using device-aware narrow + squeeze.
247        // input is [batch, seq_len, input_size]; per-timestep slices are
248        // [batch, input_size]. Both ops preserve autograd state and the
249        // tensor's device (CUDA tensors stay on CUDA, CPU tensors stay on CPU).
250        let mut timestep_inputs: Vec<Tensor<T>> = Vec::with_capacity(seq_len);
251        for t in 0..seq_len {
252            let slice = input.narrow(1, t, 1)?; // [batch, 1, input_size]
253            timestep_inputs.push(slice.squeeze_t(1)?); // [batch, input_size]
254        }
255
256        // Extract per-layer initial hidden/cell states via narrow + squeeze.
257        // h_init / c_init are [num_layers, batch, hidden_size]; each layer's
258        // slice is [batch, hidden_size]. Device-aware and autograd-preserving.
259        let hs = self.hidden_size;
260        let mut layer_h: Vec<Tensor<T>> = Vec::with_capacity(self.num_layers);
261        let mut layer_c: Vec<Tensor<T>> = Vec::with_capacity(self.num_layers);
262        for l in 0..self.num_layers {
263            layer_h.push(h_init.narrow(0, l, 1)?.squeeze_t(0)?);
264            layer_c.push(c_init.narrow(0, l, 1)?.squeeze_t(0)?);
265        }
266
267        // Run the LSTM forward pass.
268        // For each layer, iterate through all timesteps, then pass the
269        // sequence of hidden states as input to the next layer.
270        let mut layer_outputs: Vec<Tensor<T>> = timestep_inputs;
271        let mut final_h: Vec<Tensor<T>> = Vec::with_capacity(self.num_layers);
272        let mut final_c: Vec<Tensor<T>> = Vec::with_capacity(self.num_layers);
273
274        for (l, params) in self.layers.iter().enumerate() {
275            let mut h = layer_h[l].clone();
276            let mut c = layer_c[l].clone();
277            let mut next_layer_outputs: Vec<Tensor<T>> = Vec::with_capacity(seq_len);
278
279            // Hoist weight transposes outside the timestep loop — the recurrent
280            // weights are constant across timesteps, so transposing once and
281            // reusing the result across all steps eliminates seq_len-1 redundant
282            // transposes per layer (the #1679 redundant-constant-transpose class
283            // multiplied by sequence length, #1680).
284            //
285            // Autograd correctness: one transpose node now feeds every per-step
286            // matmul instead of seq_len separate transpose nodes. This is value-
287            // and gradient-identical to the per-step version: transpose_2d is a
288            // differentiable view/permute, and the autograd engine accumulates
289            // the contributions from all consuming matmuls back through the
290            // single transpose node into the weight Parameter exactly as it would
291            // have summed across the per-step transpose nodes.
292            //
293            // W_ih: [4*hs, layer_input_size], need x_t @ W_ih^T => [batch, 4*hs]
294            // W_hh: [4*hs, hs], need h @ W_hh^T => [batch, 4*hs]
295            //
296            // `transpose_2d` is a zero-copy stride swap producing a NON-
297            // contiguous view. `mm_differentiable` materializes a contiguous
298            // copy of any non-contiguous operand on every call, so the
299            // per-step loop would re-copy the (constant) transposed weight
300            // seq_len times. Materialize the contiguous transpose ONCE here so
301            // the per-step `mm` sees an already-contiguous operand and skips
302            // that copy. `contiguous()` is a differentiable identity-on-values
303            // op (ContiguousBackward), so this is value- and gradient-
304            // preserving (#1680).
305            let wih_t = transpose_2d(params.weight_ih.tensor())?.contiguous()?;
306            let whh_t = transpose_2d(params.weight_hh.tensor())?.contiguous()?;
307
308            // Batch the input-to-hidden projection into ONE GEMM across all
309            // timesteps (#1690). The input projection `x_t @ W_ih^T + b_ih`
310            // has no time dependency, so the seq_len separate small
311            // [batch, in]@[in, 4*hs] GEMMs are replaced by a single
312            // [seq_len*batch, in]@[in, 4*hs] GEMM. Only the recurrent term
313            // `h @ W_hh^T` (which depends on the previous hidden state) stays
314            // inside the per-timestep loop.
315            //
316            // This mirrors upstream's `FullLayer::operator()` CPU path at
317            // `aten/src/ATen/native/RNN.cpp:863-869`, which computes
318            // `params.linear_ih(inputs)` over the whole stacked sequence then
319            // `unbind(0)` per timestep with `pre_compute_input=true`.
320            //
321            // Autograd correctness: this is a pure reassociation. Stacking the
322            // per-step inputs (`cat` dim 0) and slicing the projection back
323            // (`narrow` dim 0) are differentiable; the gradient to `weight_ih`
324            // now accumulates through ONE matmul node whose upstream grad is
325            // the concatenation of the per-step grads — exactly the sum the
326            // per-step version produced across seq_len separate matmul nodes.
327            // The #1690 live-torch parity test pins this (grads must match
328            // torch, not be doubled or dropped).
329            let bias_ih_2d = broadcast_bias_to_batch(&params.bias_ih, batch)?;
330            let bias_hh_2d = broadcast_bias_to_batch(&params.bias_hh, batch)?;
331            let xw_all = batched_input_projection(&layer_outputs, &wih_t)?;
332
333            for (t, _x_t) in layer_outputs.iter().enumerate() {
334                // gates = x_t @ W_ih^T + bias_ih + h @ W_hh^T + bias_hh
335                // The input projection x_t @ W_ih^T is precomputed; slice the
336                // [batch, 4*hs] block for this timestep out of the batched GEMM.
337                let xw = xw_all.narrow(0, t * batch, batch)?; // [batch, 4*hs]
338                let hw = mm(&h, &whh_t)?; // [batch, 4*hs]
339
340                let gates = add(&add(&add(&xw, &bias_ih_2d)?, &hw)?, &bias_hh_2d)?;
341
342                // Split gates into i, f, g, o — each [batch, hs].
343                // Uses differentiable chunk to preserve the autograd graph.
344                let gate_chunks = gates.chunk(4, 1)?;
345                let i_pre = gate_chunks[0].clone();
346                let f_pre = gate_chunks[1].clone();
347                let g_pre = gate_chunks[2].clone();
348                let o_pre = gate_chunks[3].clone();
349
350                // Apply activations (differentiable ops — autograd will track).
351                let i_gate = sigmoid(&i_pre)?;
352                let f_gate = sigmoid(&f_pre)?;
353                let g_gate = tanh(&g_pre)?;
354                let o_gate = sigmoid(&o_pre)?;
355
356                // c_new = f * c + i * g
357                let fc = mul(&f_gate, &c)?;
358                let ig = mul(&i_gate, &g_gate)?;
359                let c_new = add(&fc, &ig)?;
360
361                // h_new = o * tanh(c_new)
362                let tanh_c = tanh(&c_new)?;
363                let h_new = mul(&o_gate, &tanh_c)?;
364
365                next_layer_outputs.push(h_new.clone());
366                h = h_new;
367                c = c_new;
368            }
369
370            final_h.push(h);
371            final_c.push(c);
372            layer_outputs = next_layer_outputs;
373        }
374
375        // Assemble output: [batch, seq_len, hidden_size] from the last layer.
376        // Each layer_outputs[t] is [batch, hs]. We need to interleave by batch
377        // to get [batch, seq_len, hs].
378        //
379        // Strategy: cat along dim=1 to get [batch, seq_len * hs], then reshape.
380        // But layer_outputs[t] is [batch, hs], and we want to stack them along
381        // a time dimension. Cat along dim=1 gives [batch, seq_len * hs].
382        let output = if seq_len == 1 {
383            // Single timestep: just reshape [batch, hs] -> [batch, 1, hs].
384            reshape(&layer_outputs[0], &[batch as isize, 1, hs as isize])?
385        } else {
386            // Cat timestep tensors along dim=1: [batch, seq_len*hs].
387            let stacked = cat(&layer_outputs, 1)?;
388            // Reshape to [batch, seq_len, hs].
389            reshape(&stacked, &[batch as isize, seq_len as isize, hs as isize])?
390        };
391
392        // Assemble h_n, c_n: [num_layers, batch, hidden_size].
393        // Cat final hidden states along dim=0: each is [batch, hs] -> [num_layers*batch, hs].
394        // Then reshape to [num_layers, batch, hs].
395        let h_n = if self.num_layers == 1 {
396            reshape(&final_h[0], &[1, batch as isize, hs as isize])?
397        } else {
398            let h_stacked = cat(&final_h, 0)?;
399            reshape(
400                &h_stacked,
401                &[self.num_layers as isize, batch as isize, hs as isize],
402            )?
403        };
404        let c_n = if self.num_layers == 1 {
405            reshape(&final_c[0], &[1, batch as isize, hs as isize])?
406        } else {
407            let c_stacked = cat(&final_c, 0)?;
408            reshape(
409                &c_stacked,
410                &[self.num_layers as isize, batch as isize, hs as isize],
411            )?
412        };
413
414        Ok((output, (h_n, c_n)))
415    }
416
417    /// Number of expected input features.
418    #[inline]
419    pub fn input_size(&self) -> usize {
420        self.input_size
421    }
422
423    /// Number of features in the hidden state.
424    #[inline]
425    pub fn hidden_size(&self) -> usize {
426        self.hidden_size
427    }
428
429    /// Number of stacked LSTM layers.
430    #[inline]
431    pub fn num_layers(&self) -> usize {
432        self.num_layers
433    }
434}
435
436// ---------------------------------------------------------------------------
437// Module trait implementation
438// ---------------------------------------------------------------------------
439
440impl<T: Float> Module<T> for LSTM<T> {
441    /// Forward pass using the `Module` interface (no explicit hidden state).
442    ///
443    /// Hidden state defaults to zeros. To pass initial state, use
444    /// [`LSTM::forward_with_state`] instead.
445    ///
446    /// Returns the output tensor of shape `[batch, seq_len, hidden_size]`.
447    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
448        let (output, _) = self.forward_with_state(input, None)?;
449        Ok(output)
450    }
451
452    fn parameters(&self) -> Vec<&Parameter<T>> {
453        let mut params = Vec::with_capacity(self.num_layers * 4);
454        for layer in &self.layers {
455            params.push(&layer.weight_ih);
456            params.push(&layer.weight_hh);
457            params.push(&layer.bias_ih);
458            params.push(&layer.bias_hh);
459        }
460        params
461    }
462
463    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
464        let mut params = Vec::with_capacity(self.num_layers * 4);
465        for layer in &mut self.layers {
466            params.push(&mut layer.weight_ih);
467            params.push(&mut layer.weight_hh);
468            params.push(&mut layer.bias_ih);
469            params.push(&mut layer.bias_hh);
470        }
471        params
472    }
473
474    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
475        let mut params = Vec::with_capacity(self.num_layers * 4);
476        for (i, layer) in self.layers.iter().enumerate() {
477            params.push((format!("layers.{i}.weight_ih"), &layer.weight_ih));
478            params.push((format!("layers.{i}.weight_hh"), &layer.weight_hh));
479            params.push((format!("layers.{i}.bias_ih"), &layer.bias_ih));
480            params.push((format!("layers.{i}.bias_hh"), &layer.bias_hh));
481        }
482        params
483    }
484
485    fn train(&mut self) {
486        self.training = true;
487    }
488
489    fn eval(&mut self) {
490        self.training = false;
491    }
492
493    fn is_training(&self) -> bool {
494        self.training
495    }
496}
497
498// ---------------------------------------------------------------------------
499// Helpers
500// ---------------------------------------------------------------------------
501
502/// Transpose a 2-D tensor (zero-copy stride swap; device-aware).
503///
504/// Delegates to `grad_fns::shape::transpose_2d` (which uses `permute_t`)
505/// rather than `ops::linalg::transpose` because the latter calls `data()`
506/// on the input — broken on GPU storage after the Phase-2a `try_as_slice`
507/// migration. The grad_fns variant is the device-aware path (see #750
508/// closure for context).
509fn transpose_2d<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
510    ferrotorch_core::grad_fns::shape::transpose_2d(input)
511}
512
513/// Batch the input-to-hidden projection across all timesteps into ONE GEMM
514/// (#1690).
515///
516/// `step_inputs` is the per-timestep sequence of `[batch, layer_input]`
517/// tensors (already extracted via `narrow` + `squeeze` from the layer input).
518/// This stacks them along a new leading axis into `[seq_len*batch,
519/// layer_input]` and runs the single RAW matmul `X_all @ W_ih^T`, returning
520/// `[seq_len*batch, gate_size]`. The result is sliced per timestep by the
521/// caller via `narrow(0, t*batch, batch)`. The bias is intentionally NOT
522/// folded in here: the LSTM/RNN generic paths add `bias_ih` after slicing,
523/// and the GRU GPU fused kernel takes the raw (bias-free) projection plus the
524/// biases as separate arguments. Keeping the projection bias-free preserves
525/// both contracts.
526///
527/// This is the same reassociation upstream performs in `FullLayer::operator()`
528/// for CPU inputs (`aten/src/ATen/native/RNN.cpp:863-869`): project the whole
529/// stacked sequence then consume per timestep with `pre_compute_input=true`.
530///
531/// Every op here (`cat`, `mm`) is differentiable and device-aware, so this
532/// preserves autograd and device placement. The gradient to `weight_ih`
533/// accumulates through one matmul node — value- and gradient-identical to
534/// summing across the per-step matmul nodes.
535fn batched_input_projection<T: Float>(
536    step_inputs: &[Tensor<T>],
537    wih_t: &Tensor<T>,
538) -> FerrotorchResult<Tensor<T>> {
539    if step_inputs.len() == 1 {
540        // Single timestep: no stacking needed.
541        return mm(&step_inputs[0], wih_t);
542    }
543    // Stack per-step [batch, in] -> [seq_len*batch, in] along dim 0, then run
544    // the single fused input-projection GEMM.
545    let x_all = cat(step_inputs, 0)?;
546    mm(&x_all, wih_t) // [seq_len*batch, gate_size]
547}
548
549/// Broadcast a 1-D bias of shape `[n]` into shape `[batch, n]`.
550///
551/// Device-aware: preserves the bias tensor's device (CPU bias stays CPU,
552/// CUDA bias stays CUDA) and autograd state, since both `unsqueeze_t` and
553/// `expand` are differentiable, device-dispatched ops.
554fn broadcast_bias_to_batch<T: Float>(
555    bias: &Parameter<T>,
556    batch: usize,
557) -> FerrotorchResult<Tensor<T>> {
558    let n = bias.tensor().shape()[0];
559    let bias_2d = bias.tensor().unsqueeze_t(0)?; // [1, n]
560    ferrotorch_core::grad_fns::shape::expand(&bias_2d, &[batch, n])
561}
562
563// ---------------------------------------------------------------------------
564// Per-layer parameter set (GRU)
565// ---------------------------------------------------------------------------
566
567/// Parameters for a single GRU layer.
568#[derive(Debug, Clone)]
569struct GRULayerParams<T: Float> {
570    /// Weight matrix for input-to-hidden: shape [3*hidden_size, input_size].
571    weight_ih: Parameter<T>,
572    /// Weight matrix for hidden-to-hidden: shape [3*hidden_size, hidden_size].
573    weight_hh: Parameter<T>,
574    /// Bias for input-to-hidden: shape [3*hidden_size].
575    bias_ih: Parameter<T>,
576    /// Bias for hidden-to-hidden: shape [3*hidden_size].
577    bias_hh: Parameter<T>,
578}
579
580// ---------------------------------------------------------------------------
581// GRU
582// ---------------------------------------------------------------------------
583
584/// A multi-layer Gated Recurrent Unit (GRU) RNN.
585///
586/// For each element in the input sequence, each layer computes:
587///
588/// ```text
589/// r_t = sigmoid(W_ir @ x_t + b_ir + W_hr @ h_{t-1} + b_hr)   // reset gate
590/// z_t = sigmoid(W_iz @ x_t + b_iz + W_hz @ h_{t-1} + b_hz)   // update gate
591/// n_t = tanh(W_in @ x_t + b_in + r_t * (W_hn @ h_{t-1} + b_hn))  // new gate
592/// h_t = (1 - z_t) * n_t + z_t * h_{t-1}
593/// ```
594///
595/// The weight matrices for all three gates are concatenated into a single
596/// `weight_ih` of shape `[3*hidden_size, input_size]` and `weight_hh` of
597/// shape `[3*hidden_size, hidden_size]`.
598///
599/// # Type parameter
600///
601/// `T` must implement [`Float`] — currently `f32` or `f64`.
602#[derive(Debug)]
603pub struct GRU<T: Float> {
604    input_size: usize,
605    hidden_size: usize,
606    num_layers: usize,
607    layers: Vec<GRULayerParams<T>>,
608    training: bool,
609}
610
611impl<T: Float> GRU<T> {
612    /// Create a new GRU module.
613    ///
614    /// # Arguments
615    ///
616    /// * `input_size` — number of expected features in the input `x`.
617    /// * `hidden_size` — number of features in the hidden state `h`.
618    ///
619    /// Creates a single-layer GRU. Use [`GRU::with_num_layers`] for stacked
620    /// layers.
621    ///
622    /// # Weight initialization
623    ///
624    /// All weights are initialized from `U(-k, k)` where `k = 1/sqrt(hidden_size)`.
625    /// Biases are initialized to zero. This matches PyTorch's default.
626    pub fn new(input_size: usize, hidden_size: usize) -> FerrotorchResult<Self> {
627        Self::with_num_layers(input_size, hidden_size, 1)
628    }
629
630    /// Create a new GRU module with multiple stacked layers.
631    ///
632    /// # Arguments
633    ///
634    /// * `input_size` — number of expected features in the input `x`.
635    /// * `hidden_size` — number of features in the hidden state `h`.
636    /// * `num_layers` — number of stacked GRU layers (must be >= 1).
637    pub fn with_num_layers(
638        input_size: usize,
639        hidden_size: usize,
640        num_layers: usize,
641    ) -> FerrotorchResult<Self> {
642        if num_layers == 0 {
643            return Err(FerrotorchError::InvalidArgument {
644                message: "GRU: num_layers must be >= 1".into(),
645            });
646        }
647        if hidden_size == 0 {
648            return Err(FerrotorchError::InvalidArgument {
649                message: "GRU: hidden_size must be >= 1".into(),
650            });
651        }
652        if input_size == 0 {
653            return Err(FerrotorchError::InvalidArgument {
654                message: "GRU: input_size must be >= 1".into(),
655            });
656        }
657
658        let k = 1.0 / (hidden_size as f64).sqrt();
659        let gate_size = 3 * hidden_size;
660
661        let mut layers = Vec::with_capacity(num_layers);
662
663        for layer_idx in 0..num_layers {
664            let layer_input_size = if layer_idx == 0 {
665                input_size
666            } else {
667                hidden_size
668            };
669
670            let mut weight_ih = Parameter::zeros(&[gate_size, layer_input_size])?;
671            let mut weight_hh = Parameter::zeros(&[gate_size, hidden_size])?;
672            let mut bias_ih = Parameter::zeros(&[gate_size])?;
673            let mut bias_hh = Parameter::zeros(&[gate_size])?;
674
675            init::uniform(&mut weight_ih, -k, k)?;
676            init::uniform(&mut weight_hh, -k, k)?;
677            init::zeros(&mut bias_ih)?;
678            init::zeros(&mut bias_hh)?;
679
680            layers.push(GRULayerParams {
681                weight_ih,
682                weight_hh,
683                bias_ih,
684                bias_hh,
685            });
686        }
687
688        Ok(Self {
689            input_size,
690            hidden_size,
691            num_layers,
692            layers,
693            training: true,
694        })
695    }
696
697    /// Forward pass with explicit hidden state.
698    ///
699    /// # Arguments
700    ///
701    /// * `input` — input tensor of shape `[batch, seq_len, input_size]`.
702    /// * `h_0` — optional hidden state of shape `[num_layers, batch, hidden_size]`.
703    ///   If `None`, initialized to zeros.
704    ///
705    /// # Returns
706    ///
707    /// A tuple `(output, h_n)` where:
708    /// - `output` has shape `[batch, seq_len, hidden_size]` (last layer outputs).
709    /// - `h_n` has shape `[num_layers, batch, hidden_size]`.
710    pub fn forward(
711        &self,
712        input: &Tensor<T>,
713        h_0: Option<&Tensor<T>>,
714    ) -> FerrotorchResult<(Tensor<T>, Tensor<T>)> {
715        // Validate input shape: [B, seq_len, input_size]
716        if input.ndim() != 3 {
717            return Err(FerrotorchError::InvalidArgument {
718                message: format!(
719                    "GRU: expected 3-D input [batch, seq_len, input_size], got shape {:?}",
720                    input.shape()
721                ),
722            });
723        }
724
725        let batch = input.shape()[0];
726        let seq_len = input.shape()[1];
727
728        if input.shape()[2] != self.input_size {
729            return Err(FerrotorchError::ShapeMismatch {
730                message: format!(
731                    "GRU: input_size mismatch: expected {}, got {}",
732                    self.input_size,
733                    input.shape()[2]
734                ),
735            });
736        }
737
738        // Initialize hidden state.
739        let h_init = match h_0 {
740            Some(h0) => {
741                let expected_shape = [self.num_layers, batch, self.hidden_size];
742                if h0.shape() != expected_shape {
743                    return Err(FerrotorchError::ShapeMismatch {
744                        message: format!(
745                            "GRU: h_0 shape mismatch: expected {:?}, got {:?}",
746                            expected_shape,
747                            h0.shape()
748                        ),
749                    });
750                }
751                h0.clone()
752            }
753            None => ferrotorch_core::zeros::<T>(&[self.num_layers, batch, self.hidden_size])?,
754        };
755
756        // Extract per-timestep input slices via device-aware narrow + squeeze.
757        // input is [batch, seq_len, input_size]; each slice is [batch, input_size].
758        let hs = self.hidden_size;
759        let mut timestep_inputs: Vec<Tensor<T>> = Vec::with_capacity(seq_len);
760        for t in 0..seq_len {
761            let slice = input.narrow(1, t, 1)?; // [batch, 1, input_size]
762            timestep_inputs.push(slice.squeeze_t(1)?); // [batch, input_size]
763        }
764
765        // Extract per-layer initial hidden states via narrow + squeeze.
766        // h_init is [num_layers, batch, hidden_size]; per-layer is [batch, hs].
767        let mut layer_h: Vec<Tensor<T>> = Vec::with_capacity(self.num_layers);
768        for l in 0..self.num_layers {
769            layer_h.push(h_init.narrow(0, l, 1)?.squeeze_t(0)?);
770        }
771
772        // Run the GRU forward pass.
773        let mut layer_outputs: Vec<Tensor<T>> = timestep_inputs;
774        let mut final_h: Vec<Tensor<T>> = Vec::with_capacity(self.num_layers);
775
776        let is_f32 = std::mem::size_of::<T>() == 4;
777
778        for (l, params) in self.layers.iter().enumerate() {
779            let mut h = layer_h[l].clone();
780            let mut next_layer_outputs: Vec<Tensor<T>> = Vec::with_capacity(seq_len);
781
782            // Hoist weight transposes outside the timestep loop — these are
783            // constant across timesteps. Materialize the contiguous transpose
784            // once so the per-step `mm` skips re-copying the non-contiguous
785            // stride-swapped view every timestep (#1680). `contiguous()` is a
786            // differentiable identity-on-values op, so this is value- and
787            // gradient-preserving.
788            let wih_t = ferrotorch_core::grad_fns::shape::transpose_2d(params.weight_ih.tensor())?
789                .contiguous()?;
790            let whh_t = ferrotorch_core::grad_fns::shape::transpose_2d(params.weight_hh.tensor())?
791                .contiguous()?;
792
793            // Batch the input-to-hidden projection into ONE GEMM across all
794            // timesteps (#1690): the seq_len small [batch, in]@[in, 3*hs]
795            // input GEMMs collapse to a single [seq_len*batch, in]@[in, 3*hs]
796            // GEMM. Only the recurrent h@W_hh^T stays per-step. The raw
797            // (bias-free) projection is what BOTH the GPU fused kernel and the
798            // generic CPU path expect for `xw`, so the per-path bias handling
799            // below is unchanged. Mirrors upstream `RNN.cpp:863-869`
800            // (`linear_ih` over the stacked sequence + `pre_compute_input`).
801            let xw_all = batched_input_projection(&layer_outputs, &wih_t)?;
802
803            // Check if we can use the fused GPU kernel.
804            let use_fused_gpu =
805                is_f32 && h.is_cuda() && ferrotorch_core::gpu_dispatch::gpu_backend().is_some();
806
807            for (t, _x_t) in layer_outputs.iter().enumerate() {
808                // Phase 1: slice this timestep's precomputed input projection
809                // and compute the recurrent gate matrix via cuBLAS GEMM.
810                let xw = xw_all.narrow(0, t * batch, batch)?; // [batch, 3*hs]
811                let hw = mm(&h, &whh_t)?; // [batch, 3*hs]
812
813                if use_fused_gpu {
814                    // ---- GPU fast path: fused pointwise kernel ----
815                    // The kernel takes raw gate matrices (no bias added) + biases
816                    // and computes all gate activations + GRU update in one launch.
817                    // The per-step `xw` is a `narrow` view into the batched
818                    // projection buffer; the fused kernel reads from the buffer
819                    // start (offset-unaware), so materialize a contiguous,
820                    // offset-0 copy of this timestep's [batch, 3*hs] block
821                    // before handing its handle to the kernel (#1690).
822                    // `contiguous()` is a differentiable identity-on-values op.
823                    let xw_c = xw.contiguous()?;
824                    let backend = ferrotorch_core::gpu_dispatch::gpu_backend()
825                        .ok_or(FerrotorchError::DeviceUnavailable)?;
826                    let (hy_handle, _workspace) = backend.fused_gru_cell_f32(
827                        xw_c.gpu_handle()?,
828                        hw.gpu_handle()?,
829                        params.bias_ih.tensor().gpu_handle()?,
830                        params.bias_hh.tensor().gpu_handle()?,
831                        h.gpu_handle()?,
832                        hs,
833                    )?;
834                    let h_new = Tensor::from_storage(
835                        TensorStorage::gpu(hy_handle),
836                        vec![batch, hs],
837                        false,
838                    )?;
839                    next_layer_outputs.push(h_new.clone());
840                    h = h_new;
841                } else {
842                    // ---- Generic path: device-aware composite ops ----
843                    // Used when the fused GPU kernel doesn't apply (e.g. f64
844                    // tensors, CPU tensors, or no GPU backend registered).
845                    // Every op below is device-dispatched, so this branch
846                    // runs on CPU for CPU tensors and on GPU for GPU tensors.
847                    let bias_ih_2d = broadcast_bias_to_batch(&params.bias_ih, batch)?;
848                    let bias_hh_2d = broadcast_bias_to_batch(&params.bias_hh, batch)?;
849
850                    let xw_b = add(&xw, &bias_ih_2d)?;
851                    let hw_b = add(&hw, &bias_hh_2d)?;
852
853                    // Split [batch, 3*hs] -> 3 x [batch, hs] via differentiable chunk.
854                    let xw_chunks = xw_b.chunk(3, 1)?;
855                    let hw_chunks = hw_b.chunk(3, 1)?;
856                    let rx = xw_chunks[0].clone();
857                    let zx = xw_chunks[1].clone();
858                    let nx = xw_chunks[2].clone();
859                    let rh = hw_chunks[0].clone();
860                    let zh = hw_chunks[1].clone();
861                    let nh = hw_chunks[2].clone();
862
863                    let r_gate = sigmoid(&add(&rx, &rh)?)?;
864                    let z_gate = sigmoid(&add(&zx, &zh)?)?;
865                    let r_nh = mul(&r_gate, &nh)?;
866                    let n_gate = tanh(&add(&nx, &r_nh)?)?;
867                    let h_minus_n = sub(&h, &n_gate)?;
868                    let z_h_minus_n = mul(&z_gate, &h_minus_n)?;
869                    let h_new = add(&n_gate, &z_h_minus_n)?;
870
871                    next_layer_outputs.push(h_new.clone());
872                    h = h_new;
873                }
874            }
875
876            final_h.push(h);
877            layer_outputs = next_layer_outputs;
878        }
879
880        // Assemble output: [batch, seq_len, hidden_size] from the last layer.
881        // Each layer_outputs[t] is [batch, hs]. Concatenate along dim=1 to
882        // produce [batch, seq_len*hs] then reshape — matches the layout used
883        // by LSTM above and preserves device + autograd.
884        let output = if seq_len == 1 {
885            reshape(&layer_outputs[0], &[batch as isize, 1, hs as isize])?
886        } else {
887            let stacked = cat(&layer_outputs, 1)?;
888            reshape(&stacked, &[batch as isize, seq_len as isize, hs as isize])?
889        };
890
891        // Assemble h_n: [num_layers, batch, hidden_size]. Each final_h[l] is
892        // [batch, hs]; concatenate along dim=0 then reshape to recover the
893        // layer dimension.
894        let h_n = if self.num_layers == 1 {
895            reshape(&final_h[0], &[1, batch as isize, hs as isize])?
896        } else {
897            let h_stacked = cat(&final_h, 0)?;
898            reshape(
899                &h_stacked,
900                &[self.num_layers as isize, batch as isize, hs as isize],
901            )?
902        };
903
904        Ok((output, h_n))
905    }
906
907    /// Number of expected input features.
908    #[inline]
909    pub fn input_size(&self) -> usize {
910        self.input_size
911    }
912
913    /// Number of features in the hidden state.
914    #[inline]
915    pub fn hidden_size(&self) -> usize {
916        self.hidden_size
917    }
918
919    /// Number of stacked GRU layers.
920    #[inline]
921    pub fn num_layers(&self) -> usize {
922        self.num_layers
923    }
924}
925
926// ---------------------------------------------------------------------------
927// Module trait implementation (GRU)
928// ---------------------------------------------------------------------------
929
930impl<T: Float> Module<T> for GRU<T> {
931    /// Forward pass using the `Module` interface (no explicit hidden state).
932    ///
933    /// Hidden state defaults to zeros. To pass initial state, use
934    /// [`GRU::forward`] instead.
935    ///
936    /// Returns the output tensor of shape `[batch, seq_len, hidden_size]`.
937    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
938        let (output, _) = GRU::forward(self, input, None)?;
939        Ok(output)
940    }
941
942    fn parameters(&self) -> Vec<&Parameter<T>> {
943        let mut params = Vec::with_capacity(self.num_layers * 4);
944        for layer in &self.layers {
945            params.push(&layer.weight_ih);
946            params.push(&layer.weight_hh);
947            params.push(&layer.bias_ih);
948            params.push(&layer.bias_hh);
949        }
950        params
951    }
952
953    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
954        let mut params = Vec::with_capacity(self.num_layers * 4);
955        for layer in &mut self.layers {
956            params.push(&mut layer.weight_ih);
957            params.push(&mut layer.weight_hh);
958            params.push(&mut layer.bias_ih);
959            params.push(&mut layer.bias_hh);
960        }
961        params
962    }
963
964    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
965        let mut params = Vec::with_capacity(self.num_layers * 4);
966        for (i, layer) in self.layers.iter().enumerate() {
967            params.push((format!("layers.{i}.weight_ih"), &layer.weight_ih));
968            params.push((format!("layers.{i}.weight_hh"), &layer.weight_hh));
969            params.push((format!("layers.{i}.bias_ih"), &layer.bias_ih));
970            params.push((format!("layers.{i}.bias_hh"), &layer.bias_hh));
971        }
972        params
973    }
974
975    fn train(&mut self) {
976        self.training = true;
977    }
978
979    fn eval(&mut self) {
980        self.training = false;
981    }
982
983    fn is_training(&self) -> bool {
984        self.training
985    }
986}
987
988// ===========================================================================
989// RNNCell
990// ===========================================================================
991
/// Nonlinearity for [`RNNCell`] and [`RNN`].
///
/// Mirrors the `nonlinearity` argument (`'tanh'` / `'relu'`) of
/// `torch.nn.RNN` and `torch.nn.RNNCell`. [`RNNNonlinearity::Tanh`] is the
/// default: it is what [`RNNCell::new`] selects, and what
/// `RNNNonlinearity::default()` returns.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum RNNNonlinearity {
    /// Hyperbolic tangent (default).
    #[default]
    Tanh,
    /// Rectified linear unit.
    ReLU,
}
1000
/// A single-step vanilla RNN cell.
///
/// Computes `h' = nonlinearity(x @ W_ih^T + b_ih + h @ W_hh^T + b_hh)`.
///
/// This is the equivalent of `torch.nn.RNNCell`.
#[derive(Debug)]
pub struct RNNCell<T: Float> {
    /// Number of features in each input row `x`.
    input_size: usize,
    /// Number of features in the hidden state `h`.
    hidden_size: usize,
    /// Activation applied to the summed pre-activation in `forward_cell`.
    nonlinearity: RNNNonlinearity,
    /// Input-to-hidden weights, shape `[hidden_size, input_size]`.
    weight_ih: Parameter<T>,
    /// Hidden-to-hidden weights, shape `[hidden_size, hidden_size]`.
    weight_hh: Parameter<T>,
    /// Input-to-hidden bias, shape `[hidden_size]`.
    bias_ih: Parameter<T>,
    /// Hidden-to-hidden bias, shape `[hidden_size]`.
    bias_hh: Parameter<T>,
    /// Training-mode flag toggled by `Module::train` / `Module::eval`.
    /// `forward_cell` does not read it.
    training: bool,
}
1017
1018impl<T: Float> RNNCell<T> {
1019    /// Create a new `RNNCell`.
1020    ///
1021    /// # Arguments
1022    ///
1023    /// * `input_size` — number of expected features in `x`.
1024    /// * `hidden_size` — number of features in the hidden state `h`.
1025    ///
1026    /// Uses tanh nonlinearity by default. Call [`RNNCell::with_nonlinearity`]
1027    /// for relu.
1028    pub fn new(input_size: usize, hidden_size: usize) -> FerrotorchResult<Self> {
1029        Self::with_nonlinearity(input_size, hidden_size, RNNNonlinearity::Tanh)
1030    }
1031
1032    /// Create a new `RNNCell` with a specified nonlinearity.
1033    pub fn with_nonlinearity(
1034        input_size: usize,
1035        hidden_size: usize,
1036        nonlinearity: RNNNonlinearity,
1037    ) -> FerrotorchResult<Self> {
1038        if hidden_size == 0 {
1039            return Err(FerrotorchError::InvalidArgument {
1040                message: "RNNCell: hidden_size must be >= 1".into(),
1041            });
1042        }
1043        if input_size == 0 {
1044            return Err(FerrotorchError::InvalidArgument {
1045                message: "RNNCell: input_size must be >= 1".into(),
1046            });
1047        }
1048
1049        let k = 1.0 / (hidden_size as f64).sqrt();
1050
1051        let mut weight_ih = Parameter::zeros(&[hidden_size, input_size])?;
1052        let mut weight_hh = Parameter::zeros(&[hidden_size, hidden_size])?;
1053        let mut bias_ih = Parameter::zeros(&[hidden_size])?;
1054        let mut bias_hh = Parameter::zeros(&[hidden_size])?;
1055
1056        init::uniform(&mut weight_ih, -k, k)?;
1057        init::uniform(&mut weight_hh, -k, k)?;
1058        init::zeros(&mut bias_ih)?;
1059        init::zeros(&mut bias_hh)?;
1060
1061        Ok(Self {
1062            input_size,
1063            hidden_size,
1064            nonlinearity,
1065            weight_ih,
1066            weight_hh,
1067            bias_ih,
1068            bias_hh,
1069            training: true,
1070        })
1071    }
1072
1073    /// Forward pass for the RNN cell.
1074    ///
1075    /// # Arguments
1076    ///
1077    /// * `input` — input tensor of shape `[batch, input_size]`.
1078    /// * `h` — hidden state of shape `[batch, hidden_size]`. If `None`,
1079    ///   initialized to zeros.
1080    ///
1081    /// # Returns
1082    ///
1083    /// New hidden state `h'` of shape `[batch, hidden_size]`.
1084    pub fn forward_cell(
1085        &self,
1086        input: &Tensor<T>,
1087        h: Option<&Tensor<T>>,
1088    ) -> FerrotorchResult<Tensor<T>> {
1089        if input.ndim() != 2 {
1090            return Err(FerrotorchError::InvalidArgument {
1091                message: format!(
1092                    "RNNCell: expected 2-D input [batch, input_size], got shape {:?}",
1093                    input.shape()
1094                ),
1095            });
1096        }
1097        let batch = input.shape()[0];
1098        if input.shape()[1] != self.input_size {
1099            return Err(FerrotorchError::ShapeMismatch {
1100                message: format!(
1101                    "RNNCell: input_size mismatch: expected {}, got {}",
1102                    self.input_size,
1103                    input.shape()[1]
1104                ),
1105            });
1106        }
1107
1108        let h_state = match h {
1109            Some(h0) => {
1110                if h0.shape() != [batch, self.hidden_size] {
1111                    return Err(FerrotorchError::ShapeMismatch {
1112                        message: format!(
1113                            "RNNCell: h shape mismatch: expected {:?}, got {:?}",
1114                            [batch, self.hidden_size],
1115                            h0.shape()
1116                        ),
1117                    });
1118                }
1119                h0.clone()
1120            }
1121            None => ferrotorch_core::zeros::<T>(&[batch, self.hidden_size])?,
1122        };
1123
1124        let wih_t = transpose_2d(self.weight_ih.tensor())?;
1125        let whh_t = transpose_2d(self.weight_hh.tensor())?;
1126
1127        let xw = mm(input, &wih_t)?; // [batch, hidden_size]
1128        let hw = mm(&h_state, &whh_t)?; // [batch, hidden_size]
1129
1130        let bias_ih_2d = broadcast_bias_to_batch(&self.bias_ih, batch)?;
1131        let bias_hh_2d = broadcast_bias_to_batch(&self.bias_hh, batch)?;
1132
1133        let pre_act = add(&add(&add(&xw, &bias_ih_2d)?, &hw)?, &bias_hh_2d)?;
1134
1135        match self.nonlinearity {
1136            RNNNonlinearity::Tanh => tanh(&pre_act),
1137            RNNNonlinearity::ReLU => relu(&pre_act),
1138        }
1139    }
1140
1141    /// Number of expected input features.
1142    #[inline]
1143    pub fn input_size(&self) -> usize {
1144        self.input_size
1145    }
1146
1147    /// Number of features in the hidden state.
1148    #[inline]
1149    pub fn hidden_size(&self) -> usize {
1150        self.hidden_size
1151    }
1152
1153    /// The nonlinearity used by this cell.
1154    #[inline]
1155    pub fn nonlinearity(&self) -> RNNNonlinearity {
1156        self.nonlinearity
1157    }
1158}
1159
1160impl<T: Float> Module<T> for RNNCell<T> {
1161    /// Forward with zero initial hidden state.
1162    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
1163        self.forward_cell(input, None)
1164    }
1165
1166    fn parameters(&self) -> Vec<&Parameter<T>> {
1167        vec![
1168            &self.weight_ih,
1169            &self.weight_hh,
1170            &self.bias_ih,
1171            &self.bias_hh,
1172        ]
1173    }
1174
1175    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
1176        vec![
1177            &mut self.weight_ih,
1178            &mut self.weight_hh,
1179            &mut self.bias_ih,
1180            &mut self.bias_hh,
1181        ]
1182    }
1183
1184    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
1185        vec![
1186            ("weight_ih".into(), &self.weight_ih),
1187            ("weight_hh".into(), &self.weight_hh),
1188            ("bias_ih".into(), &self.bias_ih),
1189            ("bias_hh".into(), &self.bias_hh),
1190        ]
1191    }
1192
1193    fn train(&mut self) {
1194        self.training = true;
1195    }
1196
1197    fn eval(&mut self) {
1198        self.training = false;
1199    }
1200
1201    fn is_training(&self) -> bool {
1202        self.training
1203    }
1204}
1205
1206// ===========================================================================
1207// LSTMCell
1208// ===========================================================================
1209
/// A single-step LSTM cell.
///
/// Computes (with `H = hidden_size`):
/// ```text
/// gates = x @ W_ih^T + b_ih + h @ W_hh^T + b_hh
/// i = sigmoid(gates[0:H])
/// f = sigmoid(gates[H:2H])
/// g = tanh(gates[2H:3H])
/// o = sigmoid(gates[3H:4H])
/// c' = f * c + i * g
/// h' = o * tanh(c')
/// ```
///
/// This is the equivalent of `torch.nn.LSTMCell`.
#[derive(Debug)]
pub struct LSTMCell<T: Float> {
    /// Number of features in each input row `x`.
    input_size: usize,
    /// Number of features in the hidden state `h` and the cell state `c`.
    hidden_size: usize,
    /// Stacked input-to-hidden weights for the i, f, g, o gates,
    /// shape `[4 * hidden_size, input_size]`.
    weight_ih: Parameter<T>,
    /// Stacked hidden-to-hidden weights, shape `[4 * hidden_size, hidden_size]`.
    weight_hh: Parameter<T>,
    /// Stacked input-to-hidden bias, shape `[4 * hidden_size]`.
    bias_ih: Parameter<T>,
    /// Stacked hidden-to-hidden bias, shape `[4 * hidden_size]`.
    bias_hh: Parameter<T>,
    /// Training-mode flag toggled by `Module::train` / `Module::eval`.
    /// `forward_cell` does not read it.
    training: bool,
}
1234
1235impl<T: Float> LSTMCell<T> {
1236    /// Create a new `LSTMCell`.
1237    ///
1238    /// # Arguments
1239    ///
1240    /// * `input_size` — number of expected features in `x`.
1241    /// * `hidden_size` — number of features in the hidden state `h`.
1242    pub fn new(input_size: usize, hidden_size: usize) -> FerrotorchResult<Self> {
1243        if hidden_size == 0 {
1244            return Err(FerrotorchError::InvalidArgument {
1245                message: "LSTMCell: hidden_size must be >= 1".into(),
1246            });
1247        }
1248        if input_size == 0 {
1249            return Err(FerrotorchError::InvalidArgument {
1250                message: "LSTMCell: input_size must be >= 1".into(),
1251            });
1252        }
1253
1254        let k = 1.0 / (hidden_size as f64).sqrt();
1255        let gate_size = 4 * hidden_size;
1256
1257        let mut weight_ih = Parameter::zeros(&[gate_size, input_size])?;
1258        let mut weight_hh = Parameter::zeros(&[gate_size, hidden_size])?;
1259        let mut bias_ih = Parameter::zeros(&[gate_size])?;
1260        let mut bias_hh = Parameter::zeros(&[gate_size])?;
1261
1262        init::uniform(&mut weight_ih, -k, k)?;
1263        init::uniform(&mut weight_hh, -k, k)?;
1264        init::zeros(&mut bias_ih)?;
1265        init::zeros(&mut bias_hh)?;
1266
1267        Ok(Self {
1268            input_size,
1269            hidden_size,
1270            weight_ih,
1271            weight_hh,
1272            bias_ih,
1273            bias_hh,
1274            training: true,
1275        })
1276    }
1277
1278    /// Forward pass for the LSTM cell.
1279    ///
1280    /// # Arguments
1281    ///
1282    /// * `input` — input tensor of shape `[batch, input_size]`.
1283    /// * `state` — optional `(h, c)` each of shape `[batch, hidden_size]`.
1284    ///   If `None`, both are initialized to zeros.
1285    ///
1286    /// # Returns
1287    ///
1288    /// `(h', c')` each of shape `[batch, hidden_size]`.
1289    pub fn forward_cell(
1290        &self,
1291        input: &Tensor<T>,
1292        state: Option<(&Tensor<T>, &Tensor<T>)>,
1293    ) -> FerrotorchResult<(Tensor<T>, Tensor<T>)> {
1294        if input.ndim() != 2 {
1295            return Err(FerrotorchError::InvalidArgument {
1296                message: format!(
1297                    "LSTMCell: expected 2-D input [batch, input_size], got shape {:?}",
1298                    input.shape()
1299                ),
1300            });
1301        }
1302        let batch = input.shape()[0];
1303        if input.shape()[1] != self.input_size {
1304            return Err(FerrotorchError::ShapeMismatch {
1305                message: format!(
1306                    "LSTMCell: input_size mismatch: expected {}, got {}",
1307                    self.input_size,
1308                    input.shape()[1]
1309                ),
1310            });
1311        }
1312
1313        let expected_h_shape = [batch, self.hidden_size];
1314
1315        let (h_state, c_state) = match state {
1316            Some((h0, c0)) => {
1317                if h0.shape() != expected_h_shape {
1318                    return Err(FerrotorchError::ShapeMismatch {
1319                        message: format!(
1320                            "LSTMCell: h shape mismatch: expected {:?}, got {:?}",
1321                            expected_h_shape,
1322                            h0.shape()
1323                        ),
1324                    });
1325                }
1326                if c0.shape() != expected_h_shape {
1327                    return Err(FerrotorchError::ShapeMismatch {
1328                        message: format!(
1329                            "LSTMCell: c shape mismatch: expected {:?}, got {:?}",
1330                            expected_h_shape,
1331                            c0.shape()
1332                        ),
1333                    });
1334                }
1335                (h0.clone(), c0.clone())
1336            }
1337            None => {
1338                let h0 = ferrotorch_core::zeros::<T>(&[batch, self.hidden_size])?;
1339                let c0 = ferrotorch_core::zeros::<T>(&[batch, self.hidden_size])?;
1340                (h0, c0)
1341            }
1342        };
1343
1344        let wih_t = transpose_2d(self.weight_ih.tensor())?;
1345        let whh_t = transpose_2d(self.weight_hh.tensor())?;
1346
1347        let xw = mm(input, &wih_t)?; // [batch, 4*hs]
1348        let hw = mm(&h_state, &whh_t)?; // [batch, 4*hs]
1349
1350        let bias_ih_2d = broadcast_bias_to_batch(&self.bias_ih, batch)?;
1351        let bias_hh_2d = broadcast_bias_to_batch(&self.bias_hh, batch)?;
1352
1353        let gates = add(&add(&add(&xw, &bias_ih_2d)?, &hw)?, &bias_hh_2d)?;
1354
1355        // Split gates into i, f, g, o — each [batch, hidden_size].
1356        let gate_chunks = gates.chunk(4, 1)?;
1357        let i_gate = sigmoid(&gate_chunks[0])?;
1358        let f_gate = sigmoid(&gate_chunks[1])?;
1359        let g_gate = tanh(&gate_chunks[2])?;
1360        let o_gate = sigmoid(&gate_chunks[3])?;
1361
1362        // c' = f * c + i * g
1363        let c_new = add(&mul(&f_gate, &c_state)?, &mul(&i_gate, &g_gate)?)?;
1364
1365        // h' = o * tanh(c')
1366        let h_new = mul(&o_gate, &tanh(&c_new)?)?;
1367
1368        Ok((h_new, c_new))
1369    }
1370
1371    /// Number of expected input features.
1372    #[inline]
1373    pub fn input_size(&self) -> usize {
1374        self.input_size
1375    }
1376
1377    /// Number of features in the hidden state.
1378    #[inline]
1379    pub fn hidden_size(&self) -> usize {
1380        self.hidden_size
1381    }
1382}
1383
1384impl<T: Float> Module<T> for LSTMCell<T> {
1385    /// Forward with zero initial state. Returns `h'` only (drops `c'`).
1386    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
1387        let (h, _c) = self.forward_cell(input, None)?;
1388        Ok(h)
1389    }
1390
1391    fn parameters(&self) -> Vec<&Parameter<T>> {
1392        vec![
1393            &self.weight_ih,
1394            &self.weight_hh,
1395            &self.bias_ih,
1396            &self.bias_hh,
1397        ]
1398    }
1399
1400    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
1401        vec![
1402            &mut self.weight_ih,
1403            &mut self.weight_hh,
1404            &mut self.bias_ih,
1405            &mut self.bias_hh,
1406        ]
1407    }
1408
1409    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
1410        vec![
1411            ("weight_ih".into(), &self.weight_ih),
1412            ("weight_hh".into(), &self.weight_hh),
1413            ("bias_ih".into(), &self.bias_ih),
1414            ("bias_hh".into(), &self.bias_hh),
1415        ]
1416    }
1417
1418    fn train(&mut self) {
1419        self.training = true;
1420    }
1421
1422    fn eval(&mut self) {
1423        self.training = false;
1424    }
1425
1426    fn is_training(&self) -> bool {
1427        self.training
1428    }
1429}
1430
1431// ===========================================================================
1432// GRUCell
1433// ===========================================================================
1434
/// A single-step GRU cell.
///
/// Computes:
/// ```text
/// r = sigmoid(x @ W_ir^T + b_ir + h @ W_hr^T + b_hr)
/// z = sigmoid(x @ W_iz^T + b_iz + h @ W_hz^T + b_hz)
/// n = tanh(x @ W_in^T + b_in + r * (h @ W_hn^T + b_hn))
/// h' = (1 - z) * n + z * h
/// ```
///
/// This is the equivalent of `torch.nn.GRUCell`.
#[derive(Debug)]
pub struct GRUCell<T: Float> {
    /// Number of features in each input row `x`.
    input_size: usize,
    /// Number of features in the hidden state `h`.
    hidden_size: usize,
    /// Stacked input-to-hidden weights for the r, z, n blocks,
    /// shape `[3 * hidden_size, input_size]`.
    weight_ih: Parameter<T>,
    /// Stacked hidden-to-hidden weights, shape `[3 * hidden_size, hidden_size]`.
    weight_hh: Parameter<T>,
    /// Stacked input-to-hidden bias, shape `[3 * hidden_size]`.
    bias_ih: Parameter<T>,
    /// Stacked hidden-to-hidden bias, shape `[3 * hidden_size]`.
    bias_hh: Parameter<T>,
    /// Training-mode flag. `forward_cell` does not read it.
    training: bool,
}
1456
1457impl<T: Float> GRUCell<T> {
1458    /// Create a new `GRUCell`.
1459    ///
1460    /// # Arguments
1461    ///
1462    /// * `input_size` — number of expected features in `x`.
1463    /// * `hidden_size` — number of features in the hidden state `h`.
1464    pub fn new(input_size: usize, hidden_size: usize) -> FerrotorchResult<Self> {
1465        if hidden_size == 0 {
1466            return Err(FerrotorchError::InvalidArgument {
1467                message: "GRUCell: hidden_size must be >= 1".into(),
1468            });
1469        }
1470        if input_size == 0 {
1471            return Err(FerrotorchError::InvalidArgument {
1472                message: "GRUCell: input_size must be >= 1".into(),
1473            });
1474        }
1475
1476        let k = 1.0 / (hidden_size as f64).sqrt();
1477        let gate_size = 3 * hidden_size;
1478
1479        let mut weight_ih = Parameter::zeros(&[gate_size, input_size])?;
1480        let mut weight_hh = Parameter::zeros(&[gate_size, hidden_size])?;
1481        let mut bias_ih = Parameter::zeros(&[gate_size])?;
1482        let mut bias_hh = Parameter::zeros(&[gate_size])?;
1483
1484        init::uniform(&mut weight_ih, -k, k)?;
1485        init::uniform(&mut weight_hh, -k, k)?;
1486        init::zeros(&mut bias_ih)?;
1487        init::zeros(&mut bias_hh)?;
1488
1489        Ok(Self {
1490            input_size,
1491            hidden_size,
1492            weight_ih,
1493            weight_hh,
1494            bias_ih,
1495            bias_hh,
1496            training: true,
1497        })
1498    }
1499
1500    /// Forward pass for the GRU cell.
1501    ///
1502    /// # Arguments
1503    ///
1504    /// * `input` — input tensor of shape `[batch, input_size]`.
1505    /// * `h` — hidden state of shape `[batch, hidden_size]`. If `None`,
1506    ///   initialized to zeros.
1507    ///
1508    /// # Returns
1509    ///
1510    /// New hidden state `h'` of shape `[batch, hidden_size]`.
1511    pub fn forward_cell(
1512        &self,
1513        input: &Tensor<T>,
1514        h: Option<&Tensor<T>>,
1515    ) -> FerrotorchResult<Tensor<T>> {
1516        if input.ndim() != 2 {
1517            return Err(FerrotorchError::InvalidArgument {
1518                message: format!(
1519                    "GRUCell: expected 2-D input [batch, input_size], got shape {:?}",
1520                    input.shape()
1521                ),
1522            });
1523        }
1524        let batch = input.shape()[0];
1525        if input.shape()[1] != self.input_size {
1526            return Err(FerrotorchError::ShapeMismatch {
1527                message: format!(
1528                    "GRUCell: input_size mismatch: expected {}, got {}",
1529                    self.input_size,
1530                    input.shape()[1]
1531                ),
1532            });
1533        }
1534
1535        let hs = self.hidden_size;
1536
1537        let h_state = match h {
1538            Some(h0) => {
1539                if h0.shape() != [batch, hs] {
1540                    return Err(FerrotorchError::ShapeMismatch {
1541                        message: format!(
1542                            "GRUCell: h shape mismatch: expected {:?}, got {:?}",
1543                            [batch, hs],
1544                            h0.shape()
1545                        ),
1546                    });
1547                }
1548                h0.clone()
1549            }
1550            None => ferrotorch_core::zeros::<T>(&[batch, hs])?,
1551        };
1552
1553        let wih_t = transpose_2d(self.weight_ih.tensor())?;
1554        let whh_t = transpose_2d(self.weight_hh.tensor())?;
1555
1556        let xw = mm(input, &wih_t)?; // [batch, 3*hs]
1557        let hw = mm(&h_state, &whh_t)?; // [batch, 3*hs]
1558
1559        let bias_ih_2d = broadcast_bias_to_batch(&self.bias_ih, batch)?;
1560        let bias_hh_2d = broadcast_bias_to_batch(&self.bias_hh, batch)?;
1561
1562        let xw_b = add(&xw, &bias_ih_2d)?;
1563        let hw_b = add(&hw, &bias_hh_2d)?;
1564
1565        // Split [batch, 3*hs] into r, z, n components — each [batch, hs] —
1566        // via differentiable, device-aware chunk along dim=1.
1567        let xw_chunks = xw_b.chunk(3, 1)?;
1568        let hw_chunks = hw_b.chunk(3, 1)?;
1569        let rx = xw_chunks[0].clone();
1570        let zx = xw_chunks[1].clone();
1571        let nx = xw_chunks[2].clone();
1572        let rh = hw_chunks[0].clone();
1573        let zh = hw_chunks[1].clone();
1574        let nh = hw_chunks[2].clone();
1575
1576        // r = sigmoid(rx + rh), z = sigmoid(zx + zh)
1577        let r_gate = sigmoid(&add(&rx, &rh)?)?;
1578        let z_gate = sigmoid(&add(&zx, &zh)?)?;
1579
1580        // n = tanh(nx + r * nh)
1581        let r_nh = mul(&r_gate, &nh)?;
1582        let n_gate = tanh(&add(&nx, &r_nh)?)?;
1583
1584        // h' = (1 - z) * n + z * h
1585        let h_minus_n = sub(&h_state, &n_gate)?;
1586        let z_h_minus_n = mul(&z_gate, &h_minus_n)?;
1587        add(&n_gate, &z_h_minus_n)
1588    }
1589
1590    /// Number of expected input features.
1591    #[inline]
1592    pub fn input_size(&self) -> usize {
1593        self.input_size
1594    }
1595
1596    /// Number of features in the hidden state.
1597    #[inline]
1598    pub fn hidden_size(&self) -> usize {
1599        self.hidden_size
1600    }
1601}
1602
1603impl<T: Float> Module<T> for GRUCell<T> {
1604    /// Forward with zero initial hidden state.
1605    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
1606        self.forward_cell(input, None)
1607    }
1608
1609    fn parameters(&self) -> Vec<&Parameter<T>> {
1610        vec![
1611            &self.weight_ih,
1612            &self.weight_hh,
1613            &self.bias_ih,
1614            &self.bias_hh,
1615        ]
1616    }
1617
1618    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
1619        vec![
1620            &mut self.weight_ih,
1621            &mut self.weight_hh,
1622            &mut self.bias_ih,
1623            &mut self.bias_hh,
1624        ]
1625    }
1626
1627    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
1628        vec![
1629            ("weight_ih".into(), &self.weight_ih),
1630            ("weight_hh".into(), &self.weight_hh),
1631            ("bias_ih".into(), &self.bias_ih),
1632            ("bias_hh".into(), &self.bias_hh),
1633        ]
1634    }
1635
1636    fn train(&mut self) {
1637        self.training = true;
1638    }
1639
1640    fn eval(&mut self) {
1641        self.training = false;
1642    }
1643
1644    fn is_training(&self) -> bool {
1645        self.training
1646    }
1647}
1648
1649// ===========================================================================
1650// RNN (multi-layer vanilla RNN)
1651// ===========================================================================
1652
/// Output type for RNN forward: `(output_sequence, h_n)`.
///
/// * `output_sequence` — `[batch, seq_len, hidden_size]`: the last layer's
///   hidden state at every timestep.
/// * `h_n` — `[num_layers, batch, hidden_size]`: the final hidden state of
///   every layer.
type RnnOutput<T> = (Tensor<T>, Tensor<T>);
1655
/// Per-layer parameters for the vanilla RNN.
///
/// `layer_input_size` below is the module's `input_size` for layer 0 and
/// `hidden_size` for every subsequent layer.
#[derive(Debug, Clone)]
struct RNNLayerParams<T: Float> {
    /// Weight matrix for input-to-hidden: shape [hidden_size, layer_input_size].
    weight_ih: Parameter<T>,
    /// Weight matrix for hidden-to-hidden: shape [hidden_size, hidden_size].
    weight_hh: Parameter<T>,
    /// Bias for input-to-hidden: shape [hidden_size].
    bias_ih: Parameter<T>,
    /// Bias for hidden-to-hidden: shape [hidden_size].
    bias_hh: Parameter<T>,
}
1668
/// A multi-layer vanilla RNN (Elman network).
///
/// For each element in the input sequence, each layer computes:
///
/// ```text
/// h_t = nonlinearity(x_t @ W_ih^T + b_ih + h_{t-1} @ W_hh^T + b_hh)
/// ```
///
/// where `nonlinearity` is either `tanh` (default) or `relu`. Layer 0 reads
/// the module input; every later layer reads the hidden states produced by
/// the layer below it.
///
/// Inputs are batch-first (`[batch, seq_len, input_size]`).
///
/// This is the equivalent of `torch.nn.RNN` with `batch_first=True`.
#[derive(Debug)]
pub struct RNN<T: Float> {
    // Feature width of the input sequence (last axis).
    input_size: usize,
    // Width of every layer's hidden state.
    hidden_size: usize,
    // Number of stacked layers; equals `layers.len()` after construction.
    num_layers: usize,
    // Activation applied to each layer's pre-activation.
    nonlinearity: RNNNonlinearity,
    // Per-layer weights and biases, bottom (input-facing) layer first.
    layers: Vec<RNNLayerParams<T>>,
    // Train/eval flag toggled through the `Module` trait.
    training: bool,
}
1689
1690impl<T: Float> RNN<T> {
1691    /// Create a new single-layer RNN with tanh nonlinearity.
1692    pub fn new(input_size: usize, hidden_size: usize) -> FerrotorchResult<Self> {
1693        Self::with_options(input_size, hidden_size, 1, RNNNonlinearity::Tanh)
1694    }
1695
1696    /// Create a new RNN with the specified number of layers and nonlinearity.
1697    ///
1698    /// # Arguments
1699    ///
1700    /// * `input_size` — number of expected features in the input `x`.
1701    /// * `hidden_size` — number of features in the hidden state `h`.
1702    /// * `num_layers` — number of stacked RNN layers (must be >= 1).
1703    /// * `nonlinearity` — activation function to use.
1704    pub fn with_options(
1705        input_size: usize,
1706        hidden_size: usize,
1707        num_layers: usize,
1708        nonlinearity: RNNNonlinearity,
1709    ) -> FerrotorchResult<Self> {
1710        if num_layers == 0 {
1711            return Err(FerrotorchError::InvalidArgument {
1712                message: "RNN: num_layers must be >= 1".into(),
1713            });
1714        }
1715        if hidden_size == 0 {
1716            return Err(FerrotorchError::InvalidArgument {
1717                message: "RNN: hidden_size must be >= 1".into(),
1718            });
1719        }
1720        if input_size == 0 {
1721            return Err(FerrotorchError::InvalidArgument {
1722                message: "RNN: input_size must be >= 1".into(),
1723            });
1724        }
1725
1726        let k = 1.0 / (hidden_size as f64).sqrt();
1727
1728        let mut layers = Vec::with_capacity(num_layers);
1729
1730        for layer_idx in 0..num_layers {
1731            let layer_input_size = if layer_idx == 0 {
1732                input_size
1733            } else {
1734                hidden_size
1735            };
1736
1737            let mut weight_ih = Parameter::zeros(&[hidden_size, layer_input_size])?;
1738            let mut weight_hh = Parameter::zeros(&[hidden_size, hidden_size])?;
1739            let mut bias_ih = Parameter::zeros(&[hidden_size])?;
1740            let mut bias_hh = Parameter::zeros(&[hidden_size])?;
1741
1742            init::uniform(&mut weight_ih, -k, k)?;
1743            init::uniform(&mut weight_hh, -k, k)?;
1744            init::zeros(&mut bias_ih)?;
1745            init::zeros(&mut bias_hh)?;
1746
1747            layers.push(RNNLayerParams {
1748                weight_ih,
1749                weight_hh,
1750                bias_ih,
1751                bias_hh,
1752            });
1753        }
1754
1755        Ok(Self {
1756            input_size,
1757            hidden_size,
1758            num_layers,
1759            nonlinearity,
1760            layers,
1761            training: true,
1762        })
1763    }
1764
1765    /// Forward pass with explicit hidden state.
1766    ///
1767    /// # Arguments
1768    ///
1769    /// * `input` — input tensor of shape `[batch, seq_len, input_size]`.
1770    /// * `h_0` — optional hidden state of shape `[num_layers, batch, hidden_size]`.
1771    ///   If `None`, initialized to zeros.
1772    ///
1773    /// # Returns
1774    ///
1775    /// A tuple `(output, h_n)` where:
1776    /// - `output` has shape `[batch, seq_len, hidden_size]` (last layer outputs).
1777    /// - `h_n` has shape `[num_layers, batch, hidden_size]`.
1778    pub fn forward_with_state(
1779        &self,
1780        input: &Tensor<T>,
1781        h_0: Option<&Tensor<T>>,
1782    ) -> FerrotorchResult<RnnOutput<T>> {
1783        if input.ndim() != 3 {
1784            return Err(FerrotorchError::InvalidArgument {
1785                message: format!(
1786                    "RNN: expected 3-D input [batch, seq_len, input_size], got shape {:?}",
1787                    input.shape()
1788                ),
1789            });
1790        }
1791
1792        let batch = input.shape()[0];
1793        let seq_len = input.shape()[1];
1794        let hs = self.hidden_size;
1795
1796        if input.shape()[2] != self.input_size {
1797            return Err(FerrotorchError::ShapeMismatch {
1798                message: format!(
1799                    "RNN: input_size mismatch: expected {}, got {}",
1800                    self.input_size,
1801                    input.shape()[2]
1802                ),
1803            });
1804        }
1805
1806        // Initialize hidden state.
1807        let h_init = match h_0 {
1808            Some(h0) => {
1809                let expected_shape = [self.num_layers, batch, hs];
1810                if h0.shape() != expected_shape {
1811                    return Err(FerrotorchError::ShapeMismatch {
1812                        message: format!(
1813                            "RNN: h_0 shape mismatch: expected {:?}, got {:?}",
1814                            expected_shape,
1815                            h0.shape()
1816                        ),
1817                    });
1818                }
1819                h0.clone()
1820            }
1821            None => ferrotorch_core::zeros::<T>(&[self.num_layers, batch, hs])?,
1822        };
1823
1824        // Extract per-timestep input slices via device-aware narrow + squeeze.
1825        // input is [batch, seq_len, input_size]; per-timestep is [batch, input_size].
1826        let mut timestep_inputs: Vec<Tensor<T>> = Vec::with_capacity(seq_len);
1827        for t in 0..seq_len {
1828            let slice = input.narrow(1, t, 1)?; // [batch, 1, input_size]
1829            timestep_inputs.push(slice.squeeze_t(1)?); // [batch, input_size]
1830        }
1831
1832        // Extract per-layer initial hidden states via narrow + squeeze.
1833        let mut layer_h: Vec<Tensor<T>> = Vec::with_capacity(self.num_layers);
1834        for l in 0..self.num_layers {
1835            layer_h.push(h_init.narrow(0, l, 1)?.squeeze_t(0)?);
1836        }
1837
1838        // Run the RNN forward pass.
1839        let mut layer_outputs: Vec<Tensor<T>> = timestep_inputs;
1840        let mut final_h: Vec<Tensor<T>> = Vec::with_capacity(self.num_layers);
1841
1842        for (l, params) in self.layers.iter().enumerate() {
1843            let mut h = layer_h[l].clone();
1844            let mut next_layer_outputs: Vec<Tensor<T>> = Vec::with_capacity(seq_len);
1845
1846            // Hoist + materialize the contiguous transpose once per layer so
1847            // the per-step `mm` does not re-copy the constant non-contiguous
1848            // transposed weight each timestep (#1680). `contiguous()` is a
1849            // differentiable identity-on-values op — value/grad preserving.
1850            let wih_t = transpose_2d(params.weight_ih.tensor())?.contiguous()?;
1851            let whh_t = transpose_2d(params.weight_hh.tensor())?.contiguous()?;
1852
1853            // Batch the input-to-hidden projection into ONE GEMM across all
1854            // timesteps (#1690): the seq_len small [batch, in]@[in, hs] input
1855            // GEMMs collapse to a single [seq_len*batch, in]@[in, hs] GEMM;
1856            // only the recurrent h@W_hh^T stays per-step. Pure reassociation —
1857            // value/grad-identical, pinned by the #1690 live-torch parity
1858            // test. Mirrors upstream `RNN.cpp:863-869`.
1859            let bias_ih_2d = broadcast_bias_to_batch(&params.bias_ih, batch)?;
1860            let bias_hh_2d = broadcast_bias_to_batch(&params.bias_hh, batch)?;
1861            let xw_all = batched_input_projection(&layer_outputs, &wih_t)?;
1862
1863            for (t, _x_t) in layer_outputs.iter().enumerate() {
1864                let xw = xw_all.narrow(0, t * batch, batch)?; // [batch, hs]
1865                let hw = mm(&h, &whh_t)?; // [batch, hs]
1866
1867                let pre_act = add(&add(&add(&xw, &bias_ih_2d)?, &hw)?, &bias_hh_2d)?;
1868
1869                let h_new = match self.nonlinearity {
1870                    RNNNonlinearity::Tanh => tanh(&pre_act)?,
1871                    RNNNonlinearity::ReLU => relu(&pre_act)?,
1872                };
1873
1874                next_layer_outputs.push(h_new.clone());
1875                h = h_new;
1876            }
1877
1878            final_h.push(h);
1879            layer_outputs = next_layer_outputs;
1880        }
1881
1882        // Assemble output: [batch, seq_len, hidden_size] from the last layer.
1883        // Each layer_outputs[t] is [batch, hs]; cat-along-1 + reshape mirrors
1884        // the LSTM/GRU output assembly and stays on-device.
1885        let output = if seq_len == 1 {
1886            reshape(&layer_outputs[0], &[batch as isize, 1, hs as isize])?
1887        } else {
1888            let stacked = cat(&layer_outputs, 1)?;
1889            reshape(&stacked, &[batch as isize, seq_len as isize, hs as isize])?
1890        };
1891
1892        // Assemble h_n: [num_layers, batch, hidden_size].
1893        let h_n = if self.num_layers == 1 {
1894            reshape(&final_h[0], &[1, batch as isize, hs as isize])?
1895        } else {
1896            let h_stacked = cat(&final_h, 0)?;
1897            reshape(
1898                &h_stacked,
1899                &[self.num_layers as isize, batch as isize, hs as isize],
1900            )?
1901        };
1902
1903        Ok((output, h_n))
1904    }
1905
1906    /// Number of expected input features.
1907    #[inline]
1908    pub fn input_size(&self) -> usize {
1909        self.input_size
1910    }
1911
1912    /// Number of features in the hidden state.
1913    #[inline]
1914    pub fn hidden_size(&self) -> usize {
1915        self.hidden_size
1916    }
1917
1918    /// Number of stacked RNN layers.
1919    #[inline]
1920    pub fn num_layers(&self) -> usize {
1921        self.num_layers
1922    }
1923
1924    /// The nonlinearity used by this module.
1925    #[inline]
1926    pub fn nonlinearity(&self) -> RNNNonlinearity {
1927        self.nonlinearity
1928    }
1929}
1930
1931impl<T: Float> Module<T> for RNN<T> {
1932    /// Forward pass using the `Module` interface (no explicit hidden state).
1933    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
1934        let (output, _) = self.forward_with_state(input, None)?;
1935        Ok(output)
1936    }
1937
1938    fn parameters(&self) -> Vec<&Parameter<T>> {
1939        let mut params = Vec::with_capacity(self.num_layers * 4);
1940        for layer in &self.layers {
1941            params.push(&layer.weight_ih);
1942            params.push(&layer.weight_hh);
1943            params.push(&layer.bias_ih);
1944            params.push(&layer.bias_hh);
1945        }
1946        params
1947    }
1948
1949    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
1950        let mut params = Vec::with_capacity(self.num_layers * 4);
1951        for layer in &mut self.layers {
1952            params.push(&mut layer.weight_ih);
1953            params.push(&mut layer.weight_hh);
1954            params.push(&mut layer.bias_ih);
1955            params.push(&mut layer.bias_hh);
1956        }
1957        params
1958    }
1959
1960    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
1961        let mut params = Vec::with_capacity(self.num_layers * 4);
1962        for (i, layer) in self.layers.iter().enumerate() {
1963            params.push((format!("layers.{i}.weight_ih"), &layer.weight_ih));
1964            params.push((format!("layers.{i}.weight_hh"), &layer.weight_hh));
1965            params.push((format!("layers.{i}.bias_ih"), &layer.bias_ih));
1966            params.push((format!("layers.{i}.bias_hh"), &layer.bias_hh));
1967        }
1968        params
1969    }
1970
1971    fn train(&mut self) {
1972        self.training = true;
1973    }
1974
1975    fn eval(&mut self) {
1976        self.training = false;
1977    }
1978
1979    fn is_training(&self) -> bool {
1980        self.training
1981    }
1982}
1983
1984// ===========================================================================
1985// Tests
1986// ===========================================================================
1987
1988#[cfg(test)]
1989mod tests {
1990    use super::*;
1991
1992    // -----------------------------------------------------------------------
1993    // Construction
1994    // -----------------------------------------------------------------------
1995
1996    #[test]
1997    fn test_lstm_new_basic() {
1998        let lstm = LSTM::<f32>::new(10, 20, 1).unwrap();
1999        assert_eq!(lstm.input_size(), 10);
2000        assert_eq!(lstm.hidden_size(), 20);
2001        assert_eq!(lstm.num_layers(), 1);
2002    }
2003
2004    #[test]
2005    fn test_lstm_parameter_count() {
2006        let lstm = LSTM::<f32>::new(10, 20, 2).unwrap();
2007        // Layer 0: weight_ih [80,10], weight_hh [80,20], bias_ih [80], bias_hh [80]
2008        // Layer 1: weight_ih [80,20], weight_hh [80,20], bias_ih [80], bias_hh [80]
2009        let params = lstm.parameters();
2010        assert_eq!(params.len(), 8); // 4 per layer * 2 layers
2011    }
2012
2013    #[test]
2014    fn test_lstm_parameter_shapes() {
2015        let lstm = LSTM::<f32>::new(10, 20, 1).unwrap();
2016        let params = lstm.parameters();
2017        // weight_ih: [80, 10]
2018        assert_eq!(params[0].shape(), &[80, 10]);
2019        // weight_hh: [80, 20]
2020        assert_eq!(params[1].shape(), &[80, 20]);
2021        // bias_ih: [80]
2022        assert_eq!(params[2].shape(), &[80]);
2023        // bias_hh: [80]
2024        assert_eq!(params[3].shape(), &[80]);
2025    }
2026
2027    #[test]
2028    fn test_lstm_new_invalid_num_layers() {
2029        assert!(LSTM::<f32>::new(10, 20, 0).is_err());
2030    }
2031
2032    #[test]
2033    fn test_lstm_new_invalid_hidden_size() {
2034        assert!(LSTM::<f32>::new(10, 0, 1).is_err());
2035    }
2036
2037    #[test]
2038    fn test_lstm_new_invalid_input_size() {
2039        assert!(LSTM::<f32>::new(0, 20, 1).is_err());
2040    }
2041
2042    // -----------------------------------------------------------------------
2043    // Weight initialization
2044    // -----------------------------------------------------------------------
2045
2046    #[test]
2047    fn test_lstm_weight_init_range() {
2048        let hs = 100;
2049        let lstm = LSTM::<f32>::new(50, hs, 1).unwrap();
2050        let k = 1.0 / (hs as f32).sqrt();
2051        let params = lstm.parameters();
2052
2053        // Weights should be in U(-k, k).
2054        for param in &params[..2] {
2055            let data = param.data().unwrap();
2056            for &v in data {
2057                assert!(
2058                    v.abs() <= k + 0.01,
2059                    "weight value {v} exceeds expected range [-{k}, {k}]"
2060                );
2061            }
2062        }
2063
2064        // Biases should be zeros.
2065        for param in &params[2..4] {
2066            let data = param.data().unwrap();
2067            assert!(
2068                data.iter().all(|&v| v == 0.0),
2069                "bias should be initialized to zeros"
2070            );
2071        }
2072    }
2073
2074    // -----------------------------------------------------------------------
2075    // Forward pass — output shapes
2076    // -----------------------------------------------------------------------
2077
2078    #[test]
2079    fn test_lstm_forward_output_shape() {
2080        let lstm = LSTM::<f32>::new(10, 20, 1).unwrap();
2081        let input = ferrotorch_core::zeros::<f32>(&[2, 5, 10]).unwrap(); // [B=2, T=5, F=10]
2082
2083        let (output, (h_n, c_n)) = lstm.forward_with_state(&input, None).unwrap();
2084
2085        assert_eq!(output.shape(), &[2, 5, 20]); // [B, T, hidden]
2086        assert_eq!(h_n.shape(), &[1, 2, 20]); // [layers, B, hidden]
2087        assert_eq!(c_n.shape(), &[1, 2, 20]);
2088    }
2089
2090    #[test]
2091    fn test_lstm_forward_multi_layer_shapes() {
2092        let lstm = LSTM::<f32>::new(8, 16, 3).unwrap();
2093        let input = ferrotorch_core::zeros::<f32>(&[4, 7, 8]).unwrap(); // [B=4, T=7, F=8]
2094
2095        let (output, (h_n, c_n)) = lstm.forward_with_state(&input, None).unwrap();
2096
2097        assert_eq!(output.shape(), &[4, 7, 16]); // [B, T, hidden]
2098        assert_eq!(h_n.shape(), &[3, 4, 16]); // [layers, B, hidden]
2099        assert_eq!(c_n.shape(), &[3, 4, 16]);
2100    }
2101
2102    #[test]
2103    fn test_lstm_module_forward_shape() {
2104        let lstm = LSTM::<f32>::new(10, 20, 1).unwrap();
2105        let input = ferrotorch_core::zeros::<f32>(&[2, 5, 10]).unwrap();
2106
2107        let output = lstm.forward(&input).unwrap();
2108        assert_eq!(output.shape(), &[2, 5, 20]);
2109    }
2110
2111    // -----------------------------------------------------------------------
2112    // Forward pass — basic sanity
2113    // -----------------------------------------------------------------------
2114
2115    #[test]
2116    fn test_lstm_forward_does_not_error() {
2117        let lstm = LSTM::<f32>::new(4, 8, 2).unwrap();
2118        let input = ferrotorch_core::randn::<f32>(&[3, 10, 4]).unwrap();
2119
2120        let result = lstm.forward_with_state(&input, None);
2121        assert!(
2122            result.is_ok(),
2123            "forward should not error: {:?}",
2124            result.err()
2125        );
2126    }
2127
2128    #[test]
2129    fn test_lstm_forward_nonzero_output() {
2130        // With random weights and random input, the output should not be all zeros.
2131        let lstm = LSTM::<f32>::new(4, 8, 1).unwrap();
2132        let input = ferrotorch_core::randn::<f32>(&[1, 3, 4]).unwrap();
2133
2134        let (output, _) = lstm.forward_with_state(&input, None).unwrap();
2135        let data = output.data().unwrap();
2136        let any_nonzero = data.iter().any(|&v| v.abs() > 1e-10);
2137        assert!(any_nonzero, "output should have non-zero values");
2138    }
2139
2140    #[test]
2141    fn test_lstm_forward_seq_len_1() {
2142        let lstm = LSTM::<f32>::new(4, 8, 1).unwrap();
2143        let input = ferrotorch_core::zeros::<f32>(&[1, 1, 4]).unwrap();
2144
2145        let (output, (h_n, c_n)) = lstm.forward_with_state(&input, None).unwrap();
2146        assert_eq!(output.shape(), &[1, 1, 8]);
2147        assert_eq!(h_n.shape(), &[1, 1, 8]);
2148        assert_eq!(c_n.shape(), &[1, 1, 8]);
2149    }
2150
2151    // -----------------------------------------------------------------------
2152    // Forward with explicit state
2153    // -----------------------------------------------------------------------
2154
2155    #[test]
2156    fn test_lstm_forward_with_initial_state() {
2157        let lstm = LSTM::<f32>::new(4, 8, 1).unwrap();
2158
2159        let h0 = ferrotorch_core::zeros::<f32>(&[1, 2, 8]).unwrap();
2160        let c0 = ferrotorch_core::zeros::<f32>(&[1, 2, 8]).unwrap();
2161        let input = ferrotorch_core::randn::<f32>(&[2, 3, 4]).unwrap();
2162
2163        let result = lstm.forward_with_state(&input, Some((&h0, &c0)));
2164        assert!(result.is_ok());
2165    }
2166
2167    #[test]
2168    fn test_lstm_forward_state_shape_mismatch() {
2169        let lstm = LSTM::<f32>::new(4, 8, 1).unwrap();
2170
2171        // Wrong batch size in h0.
2172        let h0 = ferrotorch_core::zeros::<f32>(&[1, 3, 8]).unwrap();
2173        let c0 = ferrotorch_core::zeros::<f32>(&[1, 2, 8]).unwrap();
2174        let input = ferrotorch_core::randn::<f32>(&[2, 3, 4]).unwrap();
2175
2176        assert!(lstm.forward_with_state(&input, Some((&h0, &c0))).is_err());
2177    }
2178
2179    #[test]
2180    fn test_lstm_forward_input_wrong_ndim() {
2181        let lstm = LSTM::<f32>::new(4, 8, 1).unwrap();
2182        let input = ferrotorch_core::zeros::<f32>(&[10, 4]).unwrap(); // 2-D, not 3-D
2183        assert!(lstm.forward_with_state(&input, None).is_err());
2184    }
2185
2186    #[test]
2187    fn test_lstm_forward_input_size_mismatch() {
2188        let lstm = LSTM::<f32>::new(4, 8, 1).unwrap();
2189        let input = ferrotorch_core::zeros::<f32>(&[1, 5, 7]).unwrap(); // input_size=7 != 4
2190        assert!(lstm.forward_with_state(&input, None).is_err());
2191    }
2192
2193    // -----------------------------------------------------------------------
2194    // Module trait
2195    // -----------------------------------------------------------------------
2196
2197    #[test]
2198    fn test_lstm_named_parameters() {
2199        let lstm = LSTM::<f32>::new(4, 8, 2).unwrap();
2200        let named = lstm.named_parameters();
2201        assert_eq!(named.len(), 8);
2202        assert_eq!(named[0].0, "layers.0.weight_ih");
2203        assert_eq!(named[1].0, "layers.0.weight_hh");
2204        assert_eq!(named[2].0, "layers.0.bias_ih");
2205        assert_eq!(named[3].0, "layers.0.bias_hh");
2206        assert_eq!(named[4].0, "layers.1.weight_ih");
2207        assert_eq!(named[5].0, "layers.1.weight_hh");
2208        assert_eq!(named[6].0, "layers.1.bias_ih");
2209        assert_eq!(named[7].0, "layers.1.bias_hh");
2210    }
2211
2212    #[test]
2213    fn test_lstm_train_eval() {
2214        let mut lstm = LSTM::<f32>::new(4, 8, 1).unwrap();
2215        assert!(lstm.is_training());
2216        lstm.eval();
2217        assert!(!lstm.is_training());
2218        lstm.train();
2219        assert!(lstm.is_training());
2220    }
2221
2222    #[test]
2223    fn test_lstm_all_parameters_require_grad() {
2224        let lstm = LSTM::<f32>::new(4, 8, 1).unwrap();
2225        for param in lstm.parameters() {
2226            assert!(param.requires_grad());
2227        }
2228    }
2229
2230    #[test]
2231    fn test_lstm_is_send_sync() {
2232        fn assert_send_sync<T: Send + Sync>() {}
2233        assert_send_sync::<LSTM<f32>>();
2234        assert_send_sync::<LSTM<f64>>();
2235    }
2236
2237    // -----------------------------------------------------------------------
2238    // Multi-layer: second layer input_size equals hidden_size
2239    // -----------------------------------------------------------------------
2240
2241    #[test]
2242    fn test_lstm_multi_layer_weight_shapes() {
2243        let lstm = LSTM::<f32>::new(10, 20, 3).unwrap();
2244        let params = lstm.parameters();
2245
2246        // Layer 0: weight_ih [80, 10] (input_size=10)
2247        assert_eq!(params[0].shape(), &[80, 10]);
2248        // Layer 0: weight_hh [80, 20]
2249        assert_eq!(params[1].shape(), &[80, 20]);
2250
2251        // Layer 1: weight_ih [80, 20] (input_size=hidden_size=20)
2252        assert_eq!(params[4].shape(), &[80, 20]);
2253        // Layer 1: weight_hh [80, 20]
2254        assert_eq!(params[5].shape(), &[80, 20]);
2255
2256        // Layer 2: weight_ih [80, 20]
2257        assert_eq!(params[8].shape(), &[80, 20]);
2258    }
2259
2260    // -----------------------------------------------------------------------
2261    // State dict roundtrip
2262    // -----------------------------------------------------------------------
2263
2264    #[test]
2265    fn test_lstm_state_dict_roundtrip() {
2266        let lstm = LSTM::<f32>::new(4, 8, 1).unwrap();
2267        let sd = lstm.state_dict();
2268        assert_eq!(sd.len(), 4);
2269        assert!(sd.contains_key("layers.0.weight_ih"));
2270        assert!(sd.contains_key("layers.0.weight_hh"));
2271        assert!(sd.contains_key("layers.0.bias_ih"));
2272        assert!(sd.contains_key("layers.0.bias_hh"));
2273
2274        let mut lstm2 = LSTM::<f32>::new(4, 8, 1).unwrap();
2275        lstm2.load_state_dict(&sd, true).unwrap();
2276    }
2277
2278    // -----------------------------------------------------------------------
2279    // Consistency: feeding the same input twice gives the same output
2280    // -----------------------------------------------------------------------
2281
2282    #[test]
2283    fn test_lstm_deterministic() {
2284        let lstm = LSTM::<f32>::new(4, 8, 1).unwrap();
2285        let input = ferrotorch_core::from_slice::<f32>(
2286            &[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8],
2287            &[1, 2, 4],
2288        )
2289        .unwrap();
2290
2291        let (out1, _) = lstm.forward_with_state(&input, None).unwrap();
2292        let (out2, _) = lstm.forward_with_state(&input, None).unwrap();
2293
2294        let d1 = out1.data().unwrap();
2295        let d2 = out2.data().unwrap();
2296        for (i, (&a, &b)) in d1.iter().zip(d2.iter()).enumerate() {
2297            assert!(
2298                (a - b).abs() < 1e-6,
2299                "output mismatch at index {i}: {a} vs {b}"
2300            );
2301        }
2302    }
2303
2304    // =======================================================================
2305    // GRU tests
2306    // =======================================================================
2307
2308    // -----------------------------------------------------------------------
2309    // Construction
2310    // -----------------------------------------------------------------------
2311
2312    #[test]
2313    fn test_gru_new_basic() {
2314        let gru = GRU::<f32>::new(10, 20).unwrap();
2315        assert_eq!(gru.input_size(), 10);
2316        assert_eq!(gru.hidden_size(), 20);
2317        assert_eq!(gru.num_layers(), 1);
2318    }
2319
2320    #[test]
2321    fn test_gru_with_num_layers() {
2322        let gru = GRU::<f32>::with_num_layers(10, 20, 3).unwrap();
2323        assert_eq!(gru.num_layers(), 3);
2324    }
2325
2326    #[test]
2327    fn test_gru_parameter_count() {
2328        let gru = GRU::<f32>::with_num_layers(10, 20, 2).unwrap();
2329        // Layer 0: weight_ih [60,10], weight_hh [60,20], bias_ih [60], bias_hh [60]
2330        // Layer 1: weight_ih [60,20], weight_hh [60,20], bias_ih [60], bias_hh [60]
2331        let params = gru.parameters();
2332        assert_eq!(params.len(), 8); // 4 per layer * 2 layers
2333    }
2334
2335    #[test]
2336    fn test_gru_parameter_shapes() {
2337        let gru = GRU::<f32>::new(10, 20).unwrap();
2338        let params = gru.parameters();
2339        // weight_ih: [60, 10] (3 * hidden_size = 60)
2340        assert_eq!(params[0].shape(), &[60, 10]);
2341        // weight_hh: [60, 20]
2342        assert_eq!(params[1].shape(), &[60, 20]);
2343        // bias_ih: [60]
2344        assert_eq!(params[2].shape(), &[60]);
2345        // bias_hh: [60]
2346        assert_eq!(params[3].shape(), &[60]);
2347    }
2348
2349    #[test]
2350    fn test_gru_new_invalid_num_layers() {
2351        assert!(GRU::<f32>::with_num_layers(10, 20, 0).is_err());
2352    }
2353
2354    #[test]
2355    fn test_gru_new_invalid_hidden_size() {
2356        assert!(GRU::<f32>::new(10, 0).is_err());
2357    }
2358
2359    #[test]
2360    fn test_gru_new_invalid_input_size() {
2361        assert!(GRU::<f32>::new(0, 20).is_err());
2362    }
2363
2364    // -----------------------------------------------------------------------
2365    // Weight initialization
2366    // -----------------------------------------------------------------------
2367
2368    #[test]
2369    fn test_gru_weight_init_range() {
2370        let hs = 100;
2371        let gru = GRU::<f32>::new(50, hs).unwrap();
2372        let k = 1.0 / (hs as f32).sqrt();
2373        let params = gru.parameters();
2374
2375        // Weights should be in U(-k, k).
2376        for param in &params[..2] {
2377            let data = param.data().unwrap();
2378            for &v in data {
2379                assert!(
2380                    v.abs() <= k + 0.01,
2381                    "weight value {v} exceeds expected range [-{k}, {k}]"
2382                );
2383            }
2384        }
2385
2386        // Biases should be zeros.
2387        for param in &params[2..4] {
2388            let data = param.data().unwrap();
2389            assert!(
2390                data.iter().all(|&v| v == 0.0),
2391                "bias should be initialized to zeros"
2392            );
2393        }
2394    }
2395
2396    // -----------------------------------------------------------------------
2397    // Forward pass — output shapes
2398    // -----------------------------------------------------------------------
2399
2400    #[test]
2401    fn test_gru_forward_output_shape() {
2402        let gru = GRU::<f32>::new(10, 20).unwrap();
2403        let input = ferrotorch_core::zeros::<f32>(&[2, 5, 10]).unwrap();
2404
2405        let (output, h_n) = gru.forward(&input, None).unwrap();
2406
2407        assert_eq!(output.shape(), &[2, 5, 20]); // [B, T, hidden]
2408        assert_eq!(h_n.shape(), &[1, 2, 20]); // [layers, B, hidden]
2409    }
2410
2411    #[test]
2412    fn test_gru_forward_multi_layer_shapes() {
2413        let gru = GRU::<f32>::with_num_layers(8, 16, 3).unwrap();
2414        let input = ferrotorch_core::zeros::<f32>(&[4, 7, 8]).unwrap();
2415
2416        let (output, h_n) = gru.forward(&input, None).unwrap();
2417
2418        assert_eq!(output.shape(), &[4, 7, 16]);
2419        assert_eq!(h_n.shape(), &[3, 4, 16]);
2420    }
2421
2422    #[test]
2423    fn test_gru_module_forward_shape() {
2424        let gru = GRU::<f32>::new(10, 20).unwrap();
2425        let input = ferrotorch_core::zeros::<f32>(&[2, 5, 10]).unwrap();
2426
2427        let output = <GRU<f32> as Module<f32>>::forward(&gru, &input).unwrap();
2428        assert_eq!(output.shape(), &[2, 5, 20]);
2429    }
2430
2431    // -----------------------------------------------------------------------
2432    // Forward pass — basic sanity
2433    // -----------------------------------------------------------------------
2434
2435    #[test]
2436    fn test_gru_forward_does_not_error() {
2437        let gru = GRU::<f32>::with_num_layers(4, 8, 2).unwrap();
2438        let input = ferrotorch_core::randn::<f32>(&[3, 10, 4]).unwrap();
2439
2440        let result = gru.forward(&input, None);
2441        assert!(
2442            result.is_ok(),
2443            "forward should not error: {:?}",
2444            result.err()
2445        );
2446    }
2447
2448    #[test]
2449    fn test_gru_forward_nonzero_output() {
2450        let gru = GRU::<f32>::new(4, 8).unwrap();
2451        let input = ferrotorch_core::randn::<f32>(&[1, 3, 4]).unwrap();
2452
2453        let (output, _) = gru.forward(&input, None).unwrap();
2454        let data = output.data().unwrap();
2455        let any_nonzero = data.iter().any(|&v| v.abs() > 1e-10);
2456        assert!(any_nonzero, "output should have non-zero values");
2457    }
2458
2459    #[test]
2460    fn test_gru_forward_seq_len_1() {
2461        let gru = GRU::<f32>::new(4, 8).unwrap();
2462        let input = ferrotorch_core::zeros::<f32>(&[1, 1, 4]).unwrap();
2463
2464        let (output, h_n) = gru.forward(&input, None).unwrap();
2465        assert_eq!(output.shape(), &[1, 1, 8]);
2466        assert_eq!(h_n.shape(), &[1, 1, 8]);
2467    }
2468
2469    // -----------------------------------------------------------------------
2470    // Forward with explicit state
2471    // -----------------------------------------------------------------------
2472
2473    #[test]
2474    fn test_gru_forward_with_initial_state() {
2475        let gru = GRU::<f32>::new(4, 8).unwrap();
2476
2477        let h0 = ferrotorch_core::zeros::<f32>(&[1, 2, 8]).unwrap();
2478        let input = ferrotorch_core::randn::<f32>(&[2, 3, 4]).unwrap();
2479
2480        let result = gru.forward(&input, Some(&h0));
2481        assert!(result.is_ok());
2482    }
2483
2484    #[test]
2485    fn test_gru_forward_state_shape_mismatch() {
2486        let gru = GRU::<f32>::new(4, 8).unwrap();
2487
2488        // Wrong batch size in h0.
2489        let h0 = ferrotorch_core::zeros::<f32>(&[1, 3, 8]).unwrap();
2490        let input = ferrotorch_core::randn::<f32>(&[2, 3, 4]).unwrap();
2491
2492        assert!(gru.forward(&input, Some(&h0)).is_err());
2493    }
2494
2495    #[test]
2496    fn test_gru_forward_input_wrong_ndim() {
2497        let gru = GRU::<f32>::new(4, 8).unwrap();
2498        let input = ferrotorch_core::zeros::<f32>(&[10, 4]).unwrap();
2499        assert!(gru.forward(&input, None).is_err());
2500    }
2501
2502    #[test]
2503    fn test_gru_forward_input_size_mismatch() {
2504        let gru = GRU::<f32>::new(4, 8).unwrap();
2505        let input = ferrotorch_core::zeros::<f32>(&[1, 5, 7]).unwrap();
2506        assert!(gru.forward(&input, None).is_err());
2507    }
2508
2509    // -----------------------------------------------------------------------
2510    // Module trait (GRU)
2511    // -----------------------------------------------------------------------
2512
2513    #[test]
2514    fn test_gru_named_parameters() {
2515        let gru = GRU::<f32>::with_num_layers(4, 8, 2).unwrap();
2516        let named = gru.named_parameters();
2517        assert_eq!(named.len(), 8);
2518        assert_eq!(named[0].0, "layers.0.weight_ih");
2519        assert_eq!(named[1].0, "layers.0.weight_hh");
2520        assert_eq!(named[2].0, "layers.0.bias_ih");
2521        assert_eq!(named[3].0, "layers.0.bias_hh");
2522        assert_eq!(named[4].0, "layers.1.weight_ih");
2523        assert_eq!(named[5].0, "layers.1.weight_hh");
2524        assert_eq!(named[6].0, "layers.1.bias_ih");
2525        assert_eq!(named[7].0, "layers.1.bias_hh");
2526    }
2527
2528    #[test]
2529    fn test_gru_train_eval() {
2530        let mut gru = GRU::<f32>::new(4, 8).unwrap();
2531        assert!(gru.is_training());
2532        gru.eval();
2533        assert!(!gru.is_training());
2534        gru.train();
2535        assert!(gru.is_training());
2536    }
2537
2538    #[test]
2539    fn test_gru_all_parameters_require_grad() {
2540        let gru = GRU::<f32>::new(4, 8).unwrap();
2541        for param in gru.parameters() {
2542            assert!(param.requires_grad());
2543        }
2544    }
2545
2546    #[test]
2547    fn test_gru_is_send_sync() {
2548        fn assert_send_sync<T: Send + Sync>() {}
2549        assert_send_sync::<GRU<f32>>();
2550        assert_send_sync::<GRU<f64>>();
2551    }
2552
2553    // -----------------------------------------------------------------------
2554    // Multi-layer weight shapes (GRU)
2555    // -----------------------------------------------------------------------
2556
2557    #[test]
2558    fn test_gru_multi_layer_weight_shapes() {
2559        let gru = GRU::<f32>::with_num_layers(10, 20, 3).unwrap();
2560        let params = gru.parameters();
2561
2562        // Layer 0: weight_ih [60, 10] (input_size=10)
2563        assert_eq!(params[0].shape(), &[60, 10]);
2564        // Layer 0: weight_hh [60, 20]
2565        assert_eq!(params[1].shape(), &[60, 20]);
2566
2567        // Layer 1: weight_ih [60, 20] (input_size=hidden_size=20)
2568        assert_eq!(params[4].shape(), &[60, 20]);
2569        // Layer 1: weight_hh [60, 20]
2570        assert_eq!(params[5].shape(), &[60, 20]);
2571
2572        // Layer 2: weight_ih [60, 20]
2573        assert_eq!(params[8].shape(), &[60, 20]);
2574    }
2575
2576    // -----------------------------------------------------------------------
2577    // State dict roundtrip (GRU)
2578    // -----------------------------------------------------------------------
2579
2580    #[test]
2581    fn test_gru_state_dict_roundtrip() {
2582        let gru = GRU::<f32>::new(4, 8).unwrap();
2583        let sd = gru.state_dict();
2584        assert_eq!(sd.len(), 4);
2585        assert!(sd.contains_key("layers.0.weight_ih"));
2586        assert!(sd.contains_key("layers.0.weight_hh"));
2587        assert!(sd.contains_key("layers.0.bias_ih"));
2588        assert!(sd.contains_key("layers.0.bias_hh"));
2589
2590        let mut gru2 = GRU::<f32>::new(4, 8).unwrap();
2591        gru2.load_state_dict(&sd, true).unwrap();
2592    }
2593
2594    // -----------------------------------------------------------------------
2595    // Determinism (GRU)
2596    // -----------------------------------------------------------------------
2597
2598    #[test]
2599    fn test_gru_deterministic() {
2600        let gru = GRU::<f32>::new(4, 8).unwrap();
2601        let input = ferrotorch_core::from_slice::<f32>(
2602            &[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8],
2603            &[1, 2, 4],
2604        )
2605        .unwrap();
2606
2607        let (out1, _) = gru.forward(&input, None).unwrap();
2608        let (out2, _) = gru.forward(&input, None).unwrap();
2609
2610        let d1 = out1.data().unwrap();
2611        let d2 = out2.data().unwrap();
2612        for (i, (&a, &b)) in d1.iter().zip(d2.iter()).enumerate() {
2613            assert!(
2614                (a - b).abs() < 1e-6,
2615                "output mismatch at index {i}: {a} vs {b}"
2616            );
2617        }
2618    }
2619
2620    // =======================================================================
2621    // RNNCell tests
2622    // =======================================================================
2623
2624    #[test]
2625    fn test_rnn_cell_new_basic() {
2626        let cell = RNNCell::<f32>::new(10, 20).unwrap();
2627        assert_eq!(cell.input_size(), 10);
2628        assert_eq!(cell.hidden_size(), 20);
2629        assert_eq!(cell.nonlinearity(), RNNNonlinearity::Tanh);
2630    }
2631
2632    #[test]
2633    fn test_rnn_cell_relu() {
2634        let cell = RNNCell::<f32>::with_nonlinearity(10, 20, RNNNonlinearity::ReLU).unwrap();
2635        assert_eq!(cell.nonlinearity(), RNNNonlinearity::ReLU);
2636    }
2637
2638    #[test]
2639    fn test_rnn_cell_invalid_sizes() {
2640        assert!(RNNCell::<f32>::new(0, 20).is_err());
2641        assert!(RNNCell::<f32>::new(10, 0).is_err());
2642    }
2643
2644    #[test]
2645    fn test_rnn_cell_parameter_shapes() {
2646        let cell = RNNCell::<f32>::new(10, 20).unwrap();
2647        let params = cell.parameters();
2648        assert_eq!(params.len(), 4);
2649        assert_eq!(params[0].shape(), &[20, 10]); // weight_ih
2650        assert_eq!(params[1].shape(), &[20, 20]); // weight_hh
2651        assert_eq!(params[2].shape(), &[20]); // bias_ih
2652        assert_eq!(params[3].shape(), &[20]); // bias_hh
2653    }
2654
2655    #[test]
2656    fn test_rnn_cell_forward_output_shape() {
2657        let cell = RNNCell::<f32>::new(10, 20).unwrap();
2658        let x = ferrotorch_core::randn::<f32>(&[3, 10]).unwrap();
2659        let h = cell.forward_cell(&x, None).unwrap();
2660        assert_eq!(h.shape(), &[3, 20]);
2661    }
2662
2663    #[test]
2664    fn test_rnn_cell_forward_with_hidden() {
2665        let cell = RNNCell::<f32>::new(10, 20).unwrap();
2666        let x = ferrotorch_core::randn::<f32>(&[3, 10]).unwrap();
2667        let h0 = ferrotorch_core::randn::<f32>(&[3, 20]).unwrap();
2668        let h = cell.forward_cell(&x, Some(&h0)).unwrap();
2669        assert_eq!(h.shape(), &[3, 20]);
2670    }
2671
2672    #[test]
2673    fn test_rnn_cell_forward_nonzero() {
2674        let cell = RNNCell::<f32>::new(4, 8).unwrap();
2675        let x = ferrotorch_core::randn::<f32>(&[1, 4]).unwrap();
2676        let h = cell.forward_cell(&x, None).unwrap();
2677        let data = h.data().unwrap();
2678        assert!(data.iter().any(|&v| v.abs() > 1e-10));
2679    }
2680
2681    #[test]
2682    fn test_rnn_cell_forward_bad_input_ndim() {
2683        let cell = RNNCell::<f32>::new(4, 8).unwrap();
2684        let x = ferrotorch_core::zeros::<f32>(&[1, 2, 4]).unwrap();
2685        assert!(cell.forward_cell(&x, None).is_err());
2686    }
2687
2688    #[test]
2689    fn test_rnn_cell_forward_bad_input_size() {
2690        let cell = RNNCell::<f32>::new(4, 8).unwrap();
2691        let x = ferrotorch_core::zeros::<f32>(&[1, 7]).unwrap();
2692        assert!(cell.forward_cell(&x, None).is_err());
2693    }
2694
2695    #[test]
2696    fn test_rnn_cell_forward_bad_h_shape() {
2697        let cell = RNNCell::<f32>::new(4, 8).unwrap();
2698        let x = ferrotorch_core::zeros::<f32>(&[2, 4]).unwrap();
2699        let h0 = ferrotorch_core::zeros::<f32>(&[3, 8]).unwrap(); // wrong batch
2700        assert!(cell.forward_cell(&x, Some(&h0)).is_err());
2701    }
2702
2703    #[test]
2704    fn test_rnn_cell_module_forward() {
2705        let cell = RNNCell::<f32>::new(4, 8).unwrap();
2706        let x = ferrotorch_core::randn::<f32>(&[2, 4]).unwrap();
2707        let h = <RNNCell<f32> as Module<f32>>::forward(&cell, &x).unwrap();
2708        assert_eq!(h.shape(), &[2, 8]);
2709    }
2710
2711    #[test]
2712    fn test_rnn_cell_named_parameters() {
2713        let cell = RNNCell::<f32>::new(4, 8).unwrap();
2714        let named = cell.named_parameters();
2715        assert_eq!(named.len(), 4);
2716        assert_eq!(named[0].0, "weight_ih");
2717        assert_eq!(named[1].0, "weight_hh");
2718        assert_eq!(named[2].0, "bias_ih");
2719        assert_eq!(named[3].0, "bias_hh");
2720    }
2721
2722    #[test]
2723    fn test_rnn_cell_train_eval() {
2724        let mut cell = RNNCell::<f32>::new(4, 8).unwrap();
2725        assert!(cell.is_training());
2726        cell.eval();
2727        assert!(!cell.is_training());
2728        cell.train();
2729        assert!(cell.is_training());
2730    }
2731
2732    #[test]
2733    fn test_rnn_cell_deterministic() {
2734        let cell = RNNCell::<f32>::new(4, 8).unwrap();
2735        let x = ferrotorch_core::from_slice::<f32>(&[0.1, 0.2, 0.3, 0.4], &[1, 4]).unwrap();
2736        let h1 = cell.forward_cell(&x, None).unwrap();
2737        let h2 = cell.forward_cell(&x, None).unwrap();
2738        let d1 = h1.data().unwrap();
2739        let d2 = h2.data().unwrap();
2740        for (i, (&a, &b)) in d1.iter().zip(d2.iter()).enumerate() {
2741            assert!((a - b).abs() < 1e-6, "mismatch at {i}: {a} vs {b}");
2742        }
2743    }
2744
2745    #[test]
2746    fn test_rnn_cell_relu_output_nonneg() {
2747        // With relu nonlinearity, all outputs should be >= 0.
2748        let cell = RNNCell::<f32>::with_nonlinearity(4, 8, RNNNonlinearity::ReLU).unwrap();
2749        let x = ferrotorch_core::randn::<f32>(&[5, 4]).unwrap();
2750        let h = cell.forward_cell(&x, None).unwrap();
2751        let data = h.data().unwrap();
2752        assert!(
2753            data.iter().all(|&v| v >= 0.0),
2754            "relu output should be non-negative"
2755        );
2756    }
2757
2758    #[test]
2759    fn test_rnn_cell_is_send_sync() {
2760        fn assert_send_sync<T: Send + Sync>() {}
2761        assert_send_sync::<RNNCell<f32>>();
2762        assert_send_sync::<RNNCell<f64>>();
2763    }
2764
2765    // =======================================================================
2766    // LSTMCell tests
2767    // =======================================================================
2768
2769    #[test]
2770    fn test_lstm_cell_new_basic() {
2771        let cell = LSTMCell::<f32>::new(10, 20).unwrap();
2772        assert_eq!(cell.input_size(), 10);
2773        assert_eq!(cell.hidden_size(), 20);
2774    }
2775
2776    #[test]
2777    fn test_lstm_cell_invalid_sizes() {
2778        assert!(LSTMCell::<f32>::new(0, 20).is_err());
2779        assert!(LSTMCell::<f32>::new(10, 0).is_err());
2780    }
2781
2782    #[test]
2783    fn test_lstm_cell_parameter_shapes() {
2784        let cell = LSTMCell::<f32>::new(10, 20).unwrap();
2785        let params = cell.parameters();
2786        assert_eq!(params.len(), 4);
2787        assert_eq!(params[0].shape(), &[80, 10]); // weight_ih [4*hs, input]
2788        assert_eq!(params[1].shape(), &[80, 20]); // weight_hh [4*hs, hs]
2789        assert_eq!(params[2].shape(), &[80]); // bias_ih
2790        assert_eq!(params[3].shape(), &[80]); // bias_hh
2791    }
2792
2793    #[test]
2794    fn test_lstm_cell_forward_output_shape() {
2795        let cell = LSTMCell::<f32>::new(10, 20).unwrap();
2796        let x = ferrotorch_core::randn::<f32>(&[3, 10]).unwrap();
2797        let (h, c) = cell.forward_cell(&x, None).unwrap();
2798        assert_eq!(h.shape(), &[3, 20]);
2799        assert_eq!(c.shape(), &[3, 20]);
2800    }
2801
2802    #[test]
2803    fn test_lstm_cell_forward_with_state() {
2804        let cell = LSTMCell::<f32>::new(10, 20).unwrap();
2805        let x = ferrotorch_core::randn::<f32>(&[3, 10]).unwrap();
2806        let h0 = ferrotorch_core::randn::<f32>(&[3, 20]).unwrap();
2807        let c0 = ferrotorch_core::randn::<f32>(&[3, 20]).unwrap();
2808        let (h, c) = cell.forward_cell(&x, Some((&h0, &c0))).unwrap();
2809        assert_eq!(h.shape(), &[3, 20]);
2810        assert_eq!(c.shape(), &[3, 20]);
2811    }
2812
2813    #[test]
2814    fn test_lstm_cell_forward_nonzero() {
2815        let cell = LSTMCell::<f32>::new(4, 8).unwrap();
2816        let x = ferrotorch_core::randn::<f32>(&[1, 4]).unwrap();
2817        let (h, c) = cell.forward_cell(&x, None).unwrap();
2818        let hd = h.data().unwrap();
2819        let cd = c.data().unwrap();
2820        assert!(hd.iter().any(|&v| v.abs() > 1e-10));
2821        assert!(cd.iter().any(|&v| v.abs() > 1e-10));
2822    }
2823
2824    #[test]
2825    fn test_lstm_cell_forward_bad_input_ndim() {
2826        let cell = LSTMCell::<f32>::new(4, 8).unwrap();
2827        let x = ferrotorch_core::zeros::<f32>(&[1, 2, 4]).unwrap();
2828        assert!(cell.forward_cell(&x, None).is_err());
2829    }
2830
2831    #[test]
2832    fn test_lstm_cell_forward_bad_input_size() {
2833        let cell = LSTMCell::<f32>::new(4, 8).unwrap();
2834        let x = ferrotorch_core::zeros::<f32>(&[1, 7]).unwrap();
2835        assert!(cell.forward_cell(&x, None).is_err());
2836    }
2837
2838    #[test]
2839    fn test_lstm_cell_forward_bad_h_shape() {
2840        let cell = LSTMCell::<f32>::new(4, 8).unwrap();
2841        let x = ferrotorch_core::zeros::<f32>(&[2, 4]).unwrap();
2842        let h0 = ferrotorch_core::zeros::<f32>(&[3, 8]).unwrap(); // wrong batch
2843        let c0 = ferrotorch_core::zeros::<f32>(&[2, 8]).unwrap();
2844        assert!(cell.forward_cell(&x, Some((&h0, &c0))).is_err());
2845    }
2846
2847    #[test]
2848    fn test_lstm_cell_forward_bad_c_shape() {
2849        let cell = LSTMCell::<f32>::new(4, 8).unwrap();
2850        let x = ferrotorch_core::zeros::<f32>(&[2, 4]).unwrap();
2851        let h0 = ferrotorch_core::zeros::<f32>(&[2, 8]).unwrap();
2852        let c0 = ferrotorch_core::zeros::<f32>(&[2, 99]).unwrap(); // wrong hs
2853        assert!(cell.forward_cell(&x, Some((&h0, &c0))).is_err());
2854    }
2855
2856    #[test]
2857    fn test_lstm_cell_module_forward() {
2858        let cell = LSTMCell::<f32>::new(4, 8).unwrap();
2859        let x = ferrotorch_core::randn::<f32>(&[2, 4]).unwrap();
2860        // Module::forward returns h only.
2861        let h = <LSTMCell<f32> as Module<f32>>::forward(&cell, &x).unwrap();
2862        assert_eq!(h.shape(), &[2, 8]);
2863    }
2864
2865    #[test]
2866    fn test_lstm_cell_named_parameters() {
2867        let cell = LSTMCell::<f32>::new(4, 8).unwrap();
2868        let named = cell.named_parameters();
2869        assert_eq!(named.len(), 4);
2870        assert_eq!(named[0].0, "weight_ih");
2871        assert_eq!(named[1].0, "weight_hh");
2872        assert_eq!(named[2].0, "bias_ih");
2873        assert_eq!(named[3].0, "bias_hh");
2874    }
2875
2876    #[test]
2877    fn test_lstm_cell_train_eval() {
2878        let mut cell = LSTMCell::<f32>::new(4, 8).unwrap();
2879        assert!(cell.is_training());
2880        cell.eval();
2881        assert!(!cell.is_training());
2882        cell.train();
2883        assert!(cell.is_training());
2884    }
2885
2886    #[test]
2887    fn test_lstm_cell_deterministic() {
2888        let cell = LSTMCell::<f32>::new(4, 8).unwrap();
2889        let x = ferrotorch_core::from_slice::<f32>(&[0.1, 0.2, 0.3, 0.4], &[1, 4]).unwrap();
2890        let (h1, c1) = cell.forward_cell(&x, None).unwrap();
2891        let (h2, c2) = cell.forward_cell(&x, None).unwrap();
2892        let hd1 = h1.data().unwrap();
2893        let hd2 = h2.data().unwrap();
2894        let cd1 = c1.data().unwrap();
2895        let cd2 = c2.data().unwrap();
2896        for (i, (&a, &b)) in hd1.iter().zip(hd2.iter()).enumerate() {
2897            assert!((a - b).abs() < 1e-6, "h mismatch at {i}: {a} vs {b}");
2898        }
2899        for (i, (&a, &b)) in cd1.iter().zip(cd2.iter()).enumerate() {
2900            assert!((a - b).abs() < 1e-6, "c mismatch at {i}: {a} vs {b}");
2901        }
2902    }
2903
2904    #[test]
2905    fn test_lstm_cell_h_bounded_by_tanh() {
2906        // h = o * tanh(c), so |h| <= 1 always.
2907        let cell = LSTMCell::<f32>::new(4, 8).unwrap();
2908        let x = ferrotorch_core::randn::<f32>(&[10, 4]).unwrap();
2909        let (h, _c) = cell.forward_cell(&x, None).unwrap();
2910        let data = h.data().unwrap();
2911        assert!(
2912            data.iter().all(|&v| v.abs() <= 1.0 + 1e-6),
2913            "LSTM cell h should be bounded by [-1, 1]"
2914        );
2915    }
2916
2917    #[test]
2918    fn test_lstm_cell_is_send_sync() {
2919        fn assert_send_sync<T: Send + Sync>() {}
2920        assert_send_sync::<LSTMCell<f32>>();
2921        assert_send_sync::<LSTMCell<f64>>();
2922    }
2923
2924    // =======================================================================
2925    // GRUCell tests
2926    // =======================================================================
2927
2928    #[test]
2929    fn test_gru_cell_new_basic() {
2930        let cell = GRUCell::<f32>::new(10, 20).unwrap();
2931        assert_eq!(cell.input_size(), 10);
2932        assert_eq!(cell.hidden_size(), 20);
2933    }
2934
2935    #[test]
2936    fn test_gru_cell_invalid_sizes() {
2937        assert!(GRUCell::<f32>::new(0, 20).is_err());
2938        assert!(GRUCell::<f32>::new(10, 0).is_err());
2939    }
2940
2941    #[test]
2942    fn test_gru_cell_parameter_shapes() {
2943        let cell = GRUCell::<f32>::new(10, 20).unwrap();
2944        let params = cell.parameters();
2945        assert_eq!(params.len(), 4);
2946        assert_eq!(params[0].shape(), &[60, 10]); // weight_ih [3*hs, input]
2947        assert_eq!(params[1].shape(), &[60, 20]); // weight_hh [3*hs, hs]
2948        assert_eq!(params[2].shape(), &[60]); // bias_ih
2949        assert_eq!(params[3].shape(), &[60]); // bias_hh
2950    }
2951
2952    #[test]
2953    fn test_gru_cell_forward_output_shape() {
2954        let cell = GRUCell::<f32>::new(10, 20).unwrap();
2955        let x = ferrotorch_core::randn::<f32>(&[3, 10]).unwrap();
2956        let h = cell.forward_cell(&x, None).unwrap();
2957        assert_eq!(h.shape(), &[3, 20]);
2958    }
2959
2960    #[test]
2961    fn test_gru_cell_forward_with_hidden() {
2962        let cell = GRUCell::<f32>::new(10, 20).unwrap();
2963        let x = ferrotorch_core::randn::<f32>(&[3, 10]).unwrap();
2964        let h0 = ferrotorch_core::randn::<f32>(&[3, 20]).unwrap();
2965        let h = cell.forward_cell(&x, Some(&h0)).unwrap();
2966        assert_eq!(h.shape(), &[3, 20]);
2967    }
2968
2969    #[test]
2970    fn test_gru_cell_forward_nonzero() {
2971        let cell = GRUCell::<f32>::new(4, 8).unwrap();
2972        let x = ferrotorch_core::randn::<f32>(&[1, 4]).unwrap();
2973        let h = cell.forward_cell(&x, None).unwrap();
2974        let data = h.data().unwrap();
2975        assert!(data.iter().any(|&v| v.abs() > 1e-10));
2976    }
2977
2978    #[test]
2979    fn test_gru_cell_forward_bad_input_ndim() {
2980        let cell = GRUCell::<f32>::new(4, 8).unwrap();
2981        let x = ferrotorch_core::zeros::<f32>(&[1, 2, 4]).unwrap();
2982        assert!(cell.forward_cell(&x, None).is_err());
2983    }
2984
2985    #[test]
2986    fn test_gru_cell_forward_bad_input_size() {
2987        let cell = GRUCell::<f32>::new(4, 8).unwrap();
2988        let x = ferrotorch_core::zeros::<f32>(&[1, 7]).unwrap();
2989        assert!(cell.forward_cell(&x, None).is_err());
2990    }
2991
2992    #[test]
2993    fn test_gru_cell_forward_bad_h_shape() {
2994        let cell = GRUCell::<f32>::new(4, 8).unwrap();
2995        let x = ferrotorch_core::zeros::<f32>(&[2, 4]).unwrap();
2996        let h0 = ferrotorch_core::zeros::<f32>(&[3, 8]).unwrap(); // wrong batch
2997        assert!(cell.forward_cell(&x, Some(&h0)).is_err());
2998    }
2999
3000    #[test]
3001    fn test_gru_cell_module_forward() {
3002        let cell = GRUCell::<f32>::new(4, 8).unwrap();
3003        let x = ferrotorch_core::randn::<f32>(&[2, 4]).unwrap();
3004        let h = <GRUCell<f32> as Module<f32>>::forward(&cell, &x).unwrap();
3005        assert_eq!(h.shape(), &[2, 8]);
3006    }
3007
3008    #[test]
3009    fn test_gru_cell_named_parameters() {
3010        let cell = GRUCell::<f32>::new(4, 8).unwrap();
3011        let named = cell.named_parameters();
3012        assert_eq!(named.len(), 4);
3013        assert_eq!(named[0].0, "weight_ih");
3014        assert_eq!(named[1].0, "weight_hh");
3015        assert_eq!(named[2].0, "bias_ih");
3016        assert_eq!(named[3].0, "bias_hh");
3017    }
3018
3019    #[test]
3020    fn test_gru_cell_train_eval() {
3021        let mut cell = GRUCell::<f32>::new(4, 8).unwrap();
3022        assert!(cell.is_training());
3023        cell.eval();
3024        assert!(!cell.is_training());
3025        cell.train();
3026        assert!(cell.is_training());
3027    }
3028
3029    #[test]
3030    fn test_gru_cell_deterministic() {
3031        let cell = GRUCell::<f32>::new(4, 8).unwrap();
3032        let x = ferrotorch_core::from_slice::<f32>(&[0.1, 0.2, 0.3, 0.4], &[1, 4]).unwrap();
3033        let h1 = cell.forward_cell(&x, None).unwrap();
3034        let h2 = cell.forward_cell(&x, None).unwrap();
3035        let d1 = h1.data().unwrap();
3036        let d2 = h2.data().unwrap();
3037        for (i, (&a, &b)) in d1.iter().zip(d2.iter()).enumerate() {
3038            assert!((a - b).abs() < 1e-6, "mismatch at {i}: {a} vs {b}");
3039        }
3040    }
3041
3042    #[test]
3043    fn test_gru_cell_is_send_sync() {
3044        fn assert_send_sync<T: Send + Sync>() {}
3045        assert_send_sync::<GRUCell<f32>>();
3046        assert_send_sync::<GRUCell<f64>>();
3047    }
3048
3049    // =======================================================================
3050    // RNN (multi-layer) tests
3051    // =======================================================================
3052
3053    #[test]
3054    fn test_rnn_new_basic() {
3055        let rnn = RNN::<f32>::new(10, 20).unwrap();
3056        assert_eq!(rnn.input_size(), 10);
3057        assert_eq!(rnn.hidden_size(), 20);
3058        assert_eq!(rnn.num_layers(), 1);
3059        assert_eq!(rnn.nonlinearity(), RNNNonlinearity::Tanh);
3060    }
3061
3062    #[test]
3063    fn test_rnn_with_options() {
3064        let rnn = RNN::<f32>::with_options(10, 20, 3, RNNNonlinearity::ReLU).unwrap();
3065        assert_eq!(rnn.num_layers(), 3);
3066        assert_eq!(rnn.nonlinearity(), RNNNonlinearity::ReLU);
3067    }
3068
3069    #[test]
3070    fn test_rnn_invalid_sizes() {
3071        assert!(RNN::<f32>::with_options(0, 20, 1, RNNNonlinearity::Tanh).is_err());
3072        assert!(RNN::<f32>::with_options(10, 0, 1, RNNNonlinearity::Tanh).is_err());
3073        assert!(RNN::<f32>::with_options(10, 20, 0, RNNNonlinearity::Tanh).is_err());
3074    }
3075
3076    #[test]
3077    fn test_rnn_parameter_count() {
3078        let rnn = RNN::<f32>::with_options(10, 20, 2, RNNNonlinearity::Tanh).unwrap();
3079        let params = rnn.parameters();
3080        assert_eq!(params.len(), 8); // 4 per layer * 2 layers
3081    }
3082
3083    #[test]
3084    fn test_rnn_parameter_shapes() {
3085        let rnn = RNN::<f32>::new(10, 20).unwrap();
3086        let params = rnn.parameters();
3087        assert_eq!(params[0].shape(), &[20, 10]); // weight_ih [hs, input]
3088        assert_eq!(params[1].shape(), &[20, 20]); // weight_hh [hs, hs]
3089        assert_eq!(params[2].shape(), &[20]); // bias_ih
3090        assert_eq!(params[3].shape(), &[20]); // bias_hh
3091    }
3092
3093    #[test]
3094    fn test_rnn_multi_layer_weight_shapes() {
3095        let rnn = RNN::<f32>::with_options(10, 20, 3, RNNNonlinearity::Tanh).unwrap();
3096        let params = rnn.parameters();
3097
3098        // Layer 0: weight_ih [20, 10] (input_size=10)
3099        assert_eq!(params[0].shape(), &[20, 10]);
3100        // Layer 0: weight_hh [20, 20]
3101        assert_eq!(params[1].shape(), &[20, 20]);
3102
3103        // Layer 1: weight_ih [20, 20] (input_size=hidden_size=20)
3104        assert_eq!(params[4].shape(), &[20, 20]);
3105        // Layer 1: weight_hh [20, 20]
3106        assert_eq!(params[5].shape(), &[20, 20]);
3107
3108        // Layer 2: weight_ih [20, 20]
3109        assert_eq!(params[8].shape(), &[20, 20]);
3110    }
3111
3112    #[test]
3113    fn test_rnn_forward_output_shape() {
3114        let rnn = RNN::<f32>::new(10, 20).unwrap();
3115        let input = ferrotorch_core::zeros::<f32>(&[2, 5, 10]).unwrap();
3116
3117        let (output, h_n) = rnn.forward_with_state(&input, None).unwrap();
3118
3119        assert_eq!(output.shape(), &[2, 5, 20]); // [B, T, hidden]
3120        assert_eq!(h_n.shape(), &[1, 2, 20]); // [layers, B, hidden]
3121    }
3122
3123    #[test]
3124    fn test_rnn_forward_multi_layer_shapes() {
3125        let rnn = RNN::<f32>::with_options(8, 16, 3, RNNNonlinearity::Tanh).unwrap();
3126        let input = ferrotorch_core::zeros::<f32>(&[4, 7, 8]).unwrap();
3127
3128        let (output, h_n) = rnn.forward_with_state(&input, None).unwrap();
3129
3130        assert_eq!(output.shape(), &[4, 7, 16]);
3131        assert_eq!(h_n.shape(), &[3, 4, 16]);
3132    }
3133
3134    #[test]
3135    fn test_rnn_module_forward_shape() {
3136        let rnn = RNN::<f32>::new(10, 20).unwrap();
3137        let input = ferrotorch_core::zeros::<f32>(&[2, 5, 10]).unwrap();
3138        let output = <RNN<f32> as Module<f32>>::forward(&rnn, &input).unwrap();
3139        assert_eq!(output.shape(), &[2, 5, 20]);
3140    }
3141
3142    #[test]
3143    fn test_rnn_forward_does_not_error() {
3144        let rnn = RNN::<f32>::with_options(4, 8, 2, RNNNonlinearity::Tanh).unwrap();
3145        let input = ferrotorch_core::randn::<f32>(&[3, 10, 4]).unwrap();
3146        let result = rnn.forward_with_state(&input, None);
3147        assert!(
3148            result.is_ok(),
3149            "forward should not error: {:?}",
3150            result.err()
3151        );
3152    }
3153
3154    #[test]
3155    fn test_rnn_forward_nonzero_output() {
3156        let rnn = RNN::<f32>::new(4, 8).unwrap();
3157        let input = ferrotorch_core::randn::<f32>(&[1, 3, 4]).unwrap();
3158        let (output, _) = rnn.forward_with_state(&input, None).unwrap();
3159        let data = output.data().unwrap();
3160        assert!(data.iter().any(|&v| v.abs() > 1e-10));
3161    }
3162
3163    #[test]
3164    fn test_rnn_forward_seq_len_1() {
3165        let rnn = RNN::<f32>::new(4, 8).unwrap();
3166        let input = ferrotorch_core::zeros::<f32>(&[1, 1, 4]).unwrap();
3167        let (output, h_n) = rnn.forward_with_state(&input, None).unwrap();
3168        assert_eq!(output.shape(), &[1, 1, 8]);
3169        assert_eq!(h_n.shape(), &[1, 1, 8]);
3170    }
3171
3172    #[test]
3173    fn test_rnn_forward_with_initial_state() {
3174        let rnn = RNN::<f32>::new(4, 8).unwrap();
3175        let h0 = ferrotorch_core::zeros::<f32>(&[1, 2, 8]).unwrap();
3176        let input = ferrotorch_core::randn::<f32>(&[2, 3, 4]).unwrap();
3177        let result = rnn.forward_with_state(&input, Some(&h0));
3178        assert!(result.is_ok());
3179    }
3180
3181    #[test]
3182    fn test_rnn_forward_state_shape_mismatch() {
3183        let rnn = RNN::<f32>::new(4, 8).unwrap();
3184        let h0 = ferrotorch_core::zeros::<f32>(&[1, 3, 8]).unwrap(); // wrong batch
3185        let input = ferrotorch_core::randn::<f32>(&[2, 3, 4]).unwrap();
3186        assert!(rnn.forward_with_state(&input, Some(&h0)).is_err());
3187    }
3188
3189    #[test]
3190    fn test_rnn_forward_input_wrong_ndim() {
3191        let rnn = RNN::<f32>::new(4, 8).unwrap();
3192        let input = ferrotorch_core::zeros::<f32>(&[10, 4]).unwrap();
3193        assert!(rnn.forward_with_state(&input, None).is_err());
3194    }
3195
3196    #[test]
3197    fn test_rnn_forward_input_size_mismatch() {
3198        let rnn = RNN::<f32>::new(4, 8).unwrap();
3199        let input = ferrotorch_core::zeros::<f32>(&[1, 5, 7]).unwrap();
3200        assert!(rnn.forward_with_state(&input, None).is_err());
3201    }
3202
3203    #[test]
3204    fn test_rnn_named_parameters() {
3205        let rnn = RNN::<f32>::with_options(4, 8, 2, RNNNonlinearity::Tanh).unwrap();
3206        let named = rnn.named_parameters();
3207        assert_eq!(named.len(), 8);
3208        assert_eq!(named[0].0, "layers.0.weight_ih");
3209        assert_eq!(named[1].0, "layers.0.weight_hh");
3210        assert_eq!(named[2].0, "layers.0.bias_ih");
3211        assert_eq!(named[3].0, "layers.0.bias_hh");
3212        assert_eq!(named[4].0, "layers.1.weight_ih");
3213        assert_eq!(named[5].0, "layers.1.weight_hh");
3214        assert_eq!(named[6].0, "layers.1.bias_ih");
3215        assert_eq!(named[7].0, "layers.1.bias_hh");
3216    }
3217
3218    #[test]
3219    fn test_rnn_train_eval() {
3220        let mut rnn = RNN::<f32>::new(4, 8).unwrap();
3221        assert!(rnn.is_training());
3222        rnn.eval();
3223        assert!(!rnn.is_training());
3224        rnn.train();
3225        assert!(rnn.is_training());
3226    }
3227
3228    #[test]
3229    fn test_rnn_all_parameters_require_grad() {
3230        let rnn = RNN::<f32>::new(4, 8).unwrap();
3231        for param in rnn.parameters() {
3232            assert!(param.requires_grad());
3233        }
3234    }
3235
3236    #[test]
3237    fn test_rnn_deterministic() {
3238        let rnn = RNN::<f32>::new(4, 8).unwrap();
3239        let input = ferrotorch_core::from_slice::<f32>(
3240            &[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8],
3241            &[1, 2, 4],
3242        )
3243        .unwrap();
3244
3245        let (out1, _) = rnn.forward_with_state(&input, None).unwrap();
3246        let (out2, _) = rnn.forward_with_state(&input, None).unwrap();
3247
3248        let d1 = out1.data().unwrap();
3249        let d2 = out2.data().unwrap();
3250        for (i, (&a, &b)) in d1.iter().zip(d2.iter()).enumerate() {
3251            assert!(
3252                (a - b).abs() < 1e-6,
3253                "output mismatch at index {i}: {a} vs {b}"
3254            );
3255        }
3256    }
3257
3258    #[test]
3259    fn test_rnn_state_dict_roundtrip() {
3260        let rnn = RNN::<f32>::new(4, 8).unwrap();
3261        let sd = rnn.state_dict();
3262        assert_eq!(sd.len(), 4);
3263        assert!(sd.contains_key("layers.0.weight_ih"));
3264
3265        let mut rnn2 = RNN::<f32>::new(4, 8).unwrap();
3266        rnn2.load_state_dict(&sd, true).unwrap();
3267    }
3268
3269    #[test]
3270    fn test_rnn_relu_forward() {
3271        let rnn = RNN::<f32>::with_options(4, 8, 1, RNNNonlinearity::ReLU).unwrap();
3272        let input = ferrotorch_core::randn::<f32>(&[2, 3, 4]).unwrap();
3273        let (output, h_n) = rnn.forward_with_state(&input, None).unwrap();
3274        assert_eq!(output.shape(), &[2, 3, 8]);
3275        assert_eq!(h_n.shape(), &[1, 2, 8]);
3276    }
3277
3278    #[test]
3279    fn test_rnn_is_send_sync() {
3280        fn assert_send_sync<T: Send + Sync>() {}
3281        assert_send_sync::<RNN<f32>>();
3282        assert_send_sync::<RNN<f64>>();
3283    }
3284}