//! Recurrent Neural Network Layers - RNN, LSTM, GRU
//!
//! # File
//! `crates/axonml-nn/src/layers/rnn.rs`
//!
//! # Author
//! Andrew Jewell Sr - AutomataNexus
//!
//! # Updated
//! March 8, 2026
//!
//! # Disclaimer
//! Use at own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.

use std::collections::HashMap;

use axonml_autograd::Variable;

use crate::init::{xavier_uniform, zeros};
use crate::module::Module;
use crate::parameter::Parameter;
// =============================================================================
// RNNCell
// =============================================================================

/// A single (Elman) RNN cell.
///
/// Recurrence: h' = tanh(W_ih * x + b_ih + W_hh * h + b_hh)
///
/// Weights are initialized with `xavier_uniform` and applied transposed in
/// `forward_step` (`x @ W_ih^T`, `h @ W_hh^T`).
pub struct RNNCell {
    /// Input-hidden weights (applied as `x @ W_ih^T`).
    pub weight_ih: Parameter,
    /// Hidden-hidden weights (applied as `h @ W_hh^T`).
    pub weight_hh: Parameter,
    /// Input-hidden bias, one value per hidden unit.
    pub bias_ih: Parameter,
    /// Hidden-hidden bias, one value per hidden unit.
    pub bias_hh: Parameter,
    /// Expected size of the last input dimension (checked in `forward_step`).
    input_size: usize,
    /// Size of the hidden state.
    hidden_size: usize,
}
46
47impl RNNCell {
48    /// Creates a new RNNCell.
49    pub fn new(input_size: usize, hidden_size: usize) -> Self {
50        Self {
51            weight_ih: Parameter::named("weight_ih", xavier_uniform(input_size, hidden_size), true),
52            weight_hh: Parameter::named(
53                "weight_hh",
54                xavier_uniform(hidden_size, hidden_size),
55                true,
56            ),
57            bias_ih: Parameter::named("bias_ih", zeros(&[hidden_size]), true),
58            bias_hh: Parameter::named("bias_hh", zeros(&[hidden_size]), true),
59            input_size,
60            hidden_size,
61        }
62    }
63
64    /// Returns the expected input size.
65    pub fn input_size(&self) -> usize {
66        self.input_size
67    }
68
69    /// Returns the hidden state size.
70    pub fn hidden_size(&self) -> usize {
71        self.hidden_size
72    }
73
74    /// Forward pass for a single time step.
75    pub fn forward_step(&self, input: &Variable, hidden: &Variable) -> Variable {
76        let input_features = input.data().shape().last().copied().unwrap_or(0);
77        assert_eq!(
78            input_features, self.input_size,
79            "RNNCell: expected input size {}, got {}",
80            self.input_size, input_features
81        );
82        // x @ W_ih^T + b_ih
83        let weight_ih = self.weight_ih.variable();
84        let weight_ih_t = weight_ih.transpose(0, 1);
85        let ih = input.matmul(&weight_ih_t);
86        let bias_ih = self.bias_ih.variable();
87        let ih = ih.add_var(&bias_ih);
88
89        // h @ W_hh^T + b_hh
90        let weight_hh = self.weight_hh.variable();
91        let weight_hh_t = weight_hh.transpose(0, 1);
92        let hh = hidden.matmul(&weight_hh_t);
93        let bias_hh = self.bias_hh.variable();
94        let hh = hh.add_var(&bias_hh);
95
96        // tanh(ih + hh)
97        ih.add_var(&hh).tanh()
98    }
99}
100
101impl Module for RNNCell {
102    fn forward(&self, input: &Variable) -> Variable {
103        // Initialize hidden state to zeros
104        let batch_size = input.shape()[0];
105        let hidden = Variable::new(
106            zeros(&[batch_size, self.hidden_size]),
107            input.requires_grad(),
108        );
109        self.forward_step(input, &hidden)
110    }
111
112    fn parameters(&self) -> Vec<Parameter> {
113        vec![
114            self.weight_ih.clone(),
115            self.weight_hh.clone(),
116            self.bias_ih.clone(),
117            self.bias_hh.clone(),
118        ]
119    }
120
121    fn named_parameters(&self) -> HashMap<String, Parameter> {
122        let mut params = HashMap::new();
123        params.insert("weight_ih".to_string(), self.weight_ih.clone());
124        params.insert("weight_hh".to_string(), self.weight_hh.clone());
125        params.insert("bias_ih".to_string(), self.bias_ih.clone());
126        params.insert("bias_hh".to_string(), self.bias_hh.clone());
127        params
128    }
129
130    fn name(&self) -> &'static str {
131        "RNNCell"
132    }
133}
134
// =============================================================================
// RNN
// =============================================================================

/// Multi-layer RNN.
///
/// Processes sequences through stacked `RNNCell`s; layer 0 consumes the input
/// features, each deeper layer consumes the hidden state of the layer below.
pub struct RNN {
    /// RNN cells, one per layer.
    cells: Vec<RNNCell>,
    /// Expected input feature size for layer 0.
    _input_size: usize,
    /// Hidden state size, shared by all layers.
    hidden_size: usize,
    /// Number of stacked layers.
    num_layers: usize,
    /// If true, input is (batch, seq, features); otherwise (seq, batch, features).
    batch_first: bool,
}
154
155impl RNN {
156    /// Creates a new multi-layer RNN.
157    pub fn new(input_size: usize, hidden_size: usize, num_layers: usize) -> Self {
158        Self::with_options(input_size, hidden_size, num_layers, true)
159    }
160
161    /// Creates an RNN with all options.
162    pub fn with_options(
163        input_size: usize,
164        hidden_size: usize,
165        num_layers: usize,
166        batch_first: bool,
167    ) -> Self {
168        let mut cells = Vec::with_capacity(num_layers);
169
170        // First layer takes input_size
171        cells.push(RNNCell::new(input_size, hidden_size));
172
173        // Subsequent layers take hidden_size
174        for _ in 1..num_layers {
175            cells.push(RNNCell::new(hidden_size, hidden_size));
176        }
177
178        Self {
179            cells,
180            _input_size: input_size,
181            hidden_size,
182            num_layers,
183            batch_first,
184        }
185    }
186}
187
188impl Module for RNN {
189    fn forward(&self, input: &Variable) -> Variable {
190        let shape = input.shape();
191        let (batch_size, seq_len, input_features) = if self.batch_first {
192            (shape[0], shape[1], shape[2])
193        } else {
194            (shape[1], shape[0], shape[2])
195        };
196
197        // Initialize hidden states
198        let mut hiddens: Vec<Variable> = (0..self.num_layers)
199            .map(|_| {
200                Variable::new(
201                    zeros(&[batch_size, self.hidden_size]),
202                    input.requires_grad(),
203                )
204            })
205            .collect();
206
207        // Pre-compute input-to-hidden projection for layer 0 across ALL timesteps
208        let cell0 = &self.cells[0];
209        let input_2d = input.reshape(&[batch_size * seq_len, input_features]);
210        let w_ih_t = cell0.weight_ih.variable().transpose(0, 1);
211        let ih_all = input_2d.matmul(&w_ih_t).add_var(&cell0.bias_ih.variable());
212        let ih_all_3d = ih_all.reshape(&[batch_size, seq_len, self.hidden_size]);
213
214        // Hoist weight transposes out of the per-timestep loop
215        let w_hh_t_0 = cell0.weight_hh.variable().transpose(0, 1);
216        let bias_hh_0 = cell0.bias_hh.variable();
217
218        let mut outputs = Vec::with_capacity(seq_len);
219
220        for t in 0..seq_len {
221            // Layer 0: use pre-computed ih projection + hoisted weight transpose
222            let ih_t = ih_all_3d.select(1, t);
223            let hh = hiddens[0].matmul(&w_hh_t_0).add_var(&bias_hh_0);
224            hiddens[0] = ih_t.add_var(&hh).tanh();
225
226            // Subsequent layers
227            for l in 1..self.num_layers {
228                let layer_input = hiddens[l - 1].clone();
229                hiddens[l] = self.cells[l].forward_step(&layer_input, &hiddens[l]);
230            }
231
232            outputs.push(hiddens[self.num_layers - 1].clone());
233        }
234
235        // Stack outputs using graph-tracked cat (unsqueeze + cat along time dim)
236        let time_dim = usize::from(self.batch_first);
237        let unsqueezed: Vec<Variable> = outputs.iter().map(|o| o.unsqueeze(time_dim)).collect();
238        let refs: Vec<&Variable> = unsqueezed.iter().collect();
239        Variable::cat(&refs, time_dim)
240    }
241
242    fn parameters(&self) -> Vec<Parameter> {
243        self.cells.iter().flat_map(|c| c.parameters()).collect()
244    }
245
246    fn name(&self) -> &'static str {
247        "RNN"
248    }
249}
250
// =============================================================================
// LSTMCell
// =============================================================================

/// A single LSTM cell.
///
/// The four gate projections (input, forget, cell candidate, output) are
/// packed into one weight/bias per direction, so every parameter's gate
/// dimension is `4 * hidden_size` (see `LSTMCell::new` and the `narrow`
/// slicing in `forward_step`).
pub struct LSTMCell {
    /// Input-hidden weights for all four gates, packed `[i | f | g | o]`.
    pub weight_ih: Parameter,
    /// Hidden-hidden weights for all four gates, packed `[i | f | g | o]`.
    pub weight_hh: Parameter,
    /// Input-hidden bias for all four gates (length `4 * hidden_size`).
    pub bias_ih: Parameter,
    /// Hidden-hidden bias for all four gates (length `4 * hidden_size`).
    pub bias_hh: Parameter,
    /// Expected size of the last input dimension (checked in `forward_step`).
    input_size: usize,
    /// Size of the hidden and cell states.
    hidden_size: usize,
}
270
271impl LSTMCell {
272    /// Creates a new LSTMCell.
273    pub fn new(input_size: usize, hidden_size: usize) -> Self {
274        // LSTM has 4 gates, so weight size is 4*hidden_size
275        Self {
276            weight_ih: Parameter::named(
277                "weight_ih",
278                xavier_uniform(input_size, 4 * hidden_size),
279                true,
280            ),
281            weight_hh: Parameter::named(
282                "weight_hh",
283                xavier_uniform(hidden_size, 4 * hidden_size),
284                true,
285            ),
286            bias_ih: Parameter::named("bias_ih", zeros(&[4 * hidden_size]), true),
287            bias_hh: Parameter::named("bias_hh", zeros(&[4 * hidden_size]), true),
288            input_size,
289            hidden_size,
290        }
291    }
292
293    /// Returns the expected input size.
294    pub fn input_size(&self) -> usize {
295        self.input_size
296    }
297
298    /// Returns the hidden state size.
299    pub fn hidden_size(&self) -> usize {
300        self.hidden_size
301    }
302
303    /// Forward pass returning (h', c').
304    pub fn forward_step(
305        &self,
306        input: &Variable,
307        hx: &(Variable, Variable),
308    ) -> (Variable, Variable) {
309        let input_features = input.data().shape().last().copied().unwrap_or(0);
310        assert_eq!(
311            input_features, self.input_size,
312            "LSTMCell: expected input size {}, got {}",
313            self.input_size, input_features
314        );
315
316        let (h, c) = hx;
317
318        // Compute all gates at once (x @ W^T + b)
319        let weight_ih = self.weight_ih.variable();
320        let weight_ih_t = weight_ih.transpose(0, 1);
321        let ih = input.matmul(&weight_ih_t);
322        let bias_ih = self.bias_ih.variable();
323        let ih = ih.add_var(&bias_ih);
324
325        let weight_hh = self.weight_hh.variable();
326        let weight_hh_t = weight_hh.transpose(0, 1);
327        let hh = h.matmul(&weight_hh_t);
328        let bias_hh = self.bias_hh.variable();
329        let hh = hh.add_var(&bias_hh);
330
331        let gates = ih.add_var(&hh);
332        let hs = self.hidden_size;
333
334        // Split into 4 gates using narrow (preserves gradient flow)
335        let i = gates.narrow(1, 0, hs).sigmoid();
336        let f = gates.narrow(1, hs, hs).sigmoid();
337        let g = gates.narrow(1, 2 * hs, hs).tanh();
338        let o = gates.narrow(1, 3 * hs, hs).sigmoid();
339
340        // c' = f * c + i * g
341        let c_new = f.mul_var(c).add_var(&i.mul_var(&g));
342
343        // h' = o * tanh(c')
344        let h_new = o.mul_var(&c_new.tanh());
345
346        (h_new, c_new)
347    }
348}
349
350impl Module for LSTMCell {
351    fn forward(&self, input: &Variable) -> Variable {
352        let batch_size = input.shape()[0];
353        let h = Variable::new(
354            zeros(&[batch_size, self.hidden_size]),
355            input.requires_grad(),
356        );
357        let c = Variable::new(
358            zeros(&[batch_size, self.hidden_size]),
359            input.requires_grad(),
360        );
361        let (h_new, _) = self.forward_step(input, &(h, c));
362        h_new
363    }
364
365    fn parameters(&self) -> Vec<Parameter> {
366        vec![
367            self.weight_ih.clone(),
368            self.weight_hh.clone(),
369            self.bias_ih.clone(),
370            self.bias_hh.clone(),
371        ]
372    }
373
374    fn named_parameters(&self) -> HashMap<String, Parameter> {
375        let mut params = HashMap::new();
376        params.insert("weight_ih".to_string(), self.weight_ih.clone());
377        params.insert("weight_hh".to_string(), self.weight_hh.clone());
378        params.insert("bias_ih".to_string(), self.bias_ih.clone());
379        params.insert("bias_hh".to_string(), self.bias_hh.clone());
380        params
381    }
382
383    fn name(&self) -> &'static str {
384        "LSTMCell"
385    }
386}
387
// =============================================================================
// LSTM
// =============================================================================

/// Multi-layer LSTM.
///
/// Stacks `num_layers` `LSTMCell`s; layer 0 consumes the input features,
/// each deeper layer consumes the hidden state of the layer below.
pub struct LSTM {
    /// LSTM cells, one per layer.
    cells: Vec<LSTMCell>,
    /// Expected input feature size for layer 0.
    input_size: usize,
    /// Hidden (and cell) state size, shared by all layers.
    hidden_size: usize,
    /// Number of stacked layers.
    num_layers: usize,
    /// If true, input is (batch, seq, features); otherwise (seq, batch, features).
    batch_first: bool,
}
405
406impl LSTM {
407    /// Creates a new multi-layer LSTM.
408    pub fn new(input_size: usize, hidden_size: usize, num_layers: usize) -> Self {
409        Self::with_options(input_size, hidden_size, num_layers, true)
410    }
411
412    /// Creates an LSTM with all options.
413    pub fn with_options(
414        input_size: usize,
415        hidden_size: usize,
416        num_layers: usize,
417        batch_first: bool,
418    ) -> Self {
419        let mut cells = Vec::with_capacity(num_layers);
420        cells.push(LSTMCell::new(input_size, hidden_size));
421        for _ in 1..num_layers {
422            cells.push(LSTMCell::new(hidden_size, hidden_size));
423        }
424
425        Self {
426            cells,
427            input_size,
428            hidden_size,
429            num_layers,
430            batch_first,
431        }
432    }
433
434    /// Returns the expected input size.
435    pub fn input_size(&self) -> usize {
436        self.input_size
437    }
438
439    /// Returns the hidden state size.
440    pub fn hidden_size(&self) -> usize {
441        self.hidden_size
442    }
443
444    /// Returns the number of layers.
445    pub fn num_layers(&self) -> usize {
446        self.num_layers
447    }
448}
449
450impl Module for LSTM {
451    fn forward(&self, input: &Variable) -> Variable {
452        let shape = input.shape();
453        let (batch_size, seq_len, input_features) = if self.batch_first {
454            (shape[0], shape[1], shape[2])
455        } else {
456            (shape[1], shape[0], shape[2])
457        };
458
459        let mut states: Vec<(Variable, Variable)> = (0..self.num_layers)
460            .map(|_| {
461                (
462                    Variable::new(
463                        zeros(&[batch_size, self.hidden_size]),
464                        input.requires_grad(),
465                    ),
466                    Variable::new(
467                        zeros(&[batch_size, self.hidden_size]),
468                        input.requires_grad(),
469                    ),
470                )
471            })
472            .collect();
473
474        // Pre-compute input-to-hidden projection for layer 0 across ALL timesteps
475        // input: [batch, seq, features] -> reshaped to [batch*seq, features]
476        // ih_all: [batch*seq, 4*hidden] = input_2d @ W_ih^T + bias_ih
477        // Note: matmul auto-dispatches to cuBLAS GEMM when tensors are on GPU
478        let cell0 = &self.cells[0];
479        let input_2d = input.reshape(&[batch_size * seq_len, input_features]);
480        let w_ih_t = cell0.weight_ih.variable().transpose(0, 1);
481        let ih_all = input_2d.matmul(&w_ih_t).add_var(&cell0.bias_ih.variable());
482        // ih_all_3d: [batch, seq, 4*hidden]
483        let ih_all_3d = ih_all.reshape(&[batch_size, seq_len, 4 * self.hidden_size]);
484
485        // Hoist weight transpose + bias out of the per-timestep loop
486        let w_hh_t_0 = cell0.weight_hh.variable().transpose(0, 1);
487        let bias_hh_0 = cell0.bias_hh.variable();
488
489        let mut outputs = Vec::with_capacity(seq_len);
490
491        // Check if we're on GPU for fused gate kernel path
492        #[cfg(feature = "cuda")]
493        let on_gpu = input.data().device().is_gpu();
494        #[cfg(not(feature = "cuda"))]
495        let on_gpu = false;
496
497        for t in 0..seq_len {
498            // Layer 0: use pre-computed ih projection + hoisted weight transpose
499            let ih_t = ih_all_3d.select(1, t);
500            let (h, c) = &states[0];
501
502            // h @ W_hh^T + bias_hh (cuBLAS on GPU, matrixmultiply on CPU)
503            let hh = h.matmul(&w_hh_t_0).add_var(&bias_hh_0);
504
505            // Combined gates = ih + hh
506            let gates = ih_t.add_var(&hh);
507
508            if on_gpu {
509                // GPU path: fused LSTM gate kernel (1 launch vs ~14 separate ops)
510                // gates [batch, 4*hidden], c [batch, hidden] → h_new, c_new [batch, hidden]
511                #[cfg(feature = "cuda")]
512                {
513                    let hs = self.hidden_size;
514                    let gates_data = gates.data();
515                    let c_data = c.data();
516
517                    if let Some((h_tensor, c_tensor)) = gates_data.lstm_gates_fused(&c_data, hs) {
518                        // Save forward state for backward
519                        let saved_gates = gates_data.clone();
520                        let saved_c_prev = c_data.clone();
521                        let saved_c_new = c_tensor.clone();
522
523                        // Create proper backward that calls LSTM backward kernel
524                        let backward_fn = axonml_autograd::LstmGatesBackward::new(
525                            gates.grad_fn().cloned(),
526                            c.grad_fn().cloned(),
527                            saved_gates,
528                            saved_c_prev,
529                            saved_c_new,
530                            hs,
531                        );
532                        let grad_fn = axonml_autograd::GradFn::new(backward_fn);
533
534                        let h_new = Variable::from_operation(
535                            h_tensor,
536                            grad_fn.clone(),
537                            input.requires_grad(),
538                        );
539                        let c_new =
540                            Variable::from_operation(c_tensor, grad_fn, input.requires_grad());
541                        states[0] = (h_new, c_new);
542                    }
543                }
544            } else {
545                // CPU path: individual ops (each autograd-tracked)
546                let hs = self.hidden_size;
547                let i_gate = gates.narrow(1, 0, hs).sigmoid();
548                let f_gate = gates.narrow(1, hs, hs).sigmoid();
549                let g_gate = gates.narrow(1, 2 * hs, hs).tanh();
550                let o_gate = gates.narrow(1, 3 * hs, hs).sigmoid();
551                let c_new = f_gate.mul_var(c).add_var(&i_gate.mul_var(&g_gate));
552                let h_new = o_gate.mul_var(&c_new.tanh());
553                states[0] = (h_new, c_new);
554            }
555
556            // Subsequent layers use the regular cell forward_step
557            for l in 1..self.num_layers {
558                let layer_input = states[l - 1].0.clone();
559                states[l] = self.cells[l].forward_step(&layer_input, &states[l]);
560            }
561
562            outputs.push(states[self.num_layers - 1].0.clone());
563        }
564
565        // Stack outputs along the time dimension
566        let time_dim = usize::from(self.batch_first);
567        let unsqueezed: Vec<Variable> = outputs.iter().map(|o| o.unsqueeze(time_dim)).collect();
568        let refs: Vec<&Variable> = unsqueezed.iter().collect();
569        Variable::cat(&refs, time_dim)
570    }
571
572    fn parameters(&self) -> Vec<Parameter> {
573        self.cells.iter().flat_map(|c| c.parameters()).collect()
574    }
575
576    fn named_parameters(&self) -> HashMap<String, Parameter> {
577        let mut params = HashMap::new();
578        if self.cells.len() == 1 {
579            // Single layer: expose directly without cell index prefix
580            for (n, p) in self.cells[0].named_parameters() {
581                params.insert(n, p);
582            }
583        } else {
584            for (i, cell) in self.cells.iter().enumerate() {
585                for (n, p) in cell.named_parameters() {
586                    params.insert(format!("cells.{i}.{n}"), p);
587                }
588            }
589        }
590        params
591    }
592
593    fn name(&self) -> &'static str {
594        "LSTM"
595    }
596}
597
// =============================================================================
// GRUCell and GRU
// =============================================================================

/// A single GRU cell.
///
/// h' = (1 - z) * n + z * h
/// where:
///   r = sigmoid(W_ir * x + b_ir + W_hr * h + b_hr)  (reset gate)
///   z = sigmoid(W_iz * x + b_iz + W_hz * h + b_hz)  (update gate)
///   n = tanh(W_in * x + b_in + r * (W_hn * h + b_hn))  (new gate)
///
/// The three gate projections are packed `[r | z | n]`, so every parameter's
/// gate dimension is `3 * hidden_size` (see `GRUCell::new`).
pub struct GRUCell {
    /// Input-hidden weights for all gates, packed `[r | z | n]`.
    pub weight_ih: Parameter,
    /// Hidden-hidden weights for all gates, packed `[r | z | n]`.
    pub weight_hh: Parameter,
    /// Input-hidden bias for all gates (length `3 * hidden_size`).
    pub bias_ih: Parameter,
    /// Hidden-hidden bias for all gates (length `3 * hidden_size`).
    pub bias_hh: Parameter,
    /// Expected size of the last input dimension.
    input_size: usize,
    /// Size of the hidden state.
    hidden_size: usize,
}
623
624impl GRUCell {
625    /// Creates a new GRU cell.
626    pub fn new(input_size: usize, hidden_size: usize) -> Self {
627        Self {
628            weight_ih: Parameter::named(
629                "weight_ih",
630                xavier_uniform(input_size, 3 * hidden_size),
631                true,
632            ),
633            weight_hh: Parameter::named(
634                "weight_hh",
635                xavier_uniform(hidden_size, 3 * hidden_size),
636                true,
637            ),
638            bias_ih: Parameter::named("bias_ih", zeros(&[3 * hidden_size]), true),
639            bias_hh: Parameter::named("bias_hh", zeros(&[3 * hidden_size]), true),
640            input_size,
641            hidden_size,
642        }
643    }
644
645    /// Returns the expected input size.
646    pub fn input_size(&self) -> usize {
647        self.input_size
648    }
649
650    /// Returns the hidden state size.
651    pub fn hidden_size(&self) -> usize {
652        self.hidden_size
653    }
654}
655
656impl GRUCell {
657    /// Forward pass for a single time step with explicit hidden state.
658    ///
659    /// GRU equations:
660    /// r_t = sigmoid(W_ir @ x_t + b_ir + W_hr @ h_{t-1} + b_hr)
661    /// z_t = sigmoid(W_iz @ x_t + b_iz + W_hz @ h_{t-1} + b_hz)
662    /// n_t = tanh(W_in @ x_t + b_in + r_t * (W_hn @ h_{t-1} + b_hn))
663    /// h_t = (1 - z_t) * n_t + z_t * h_{t-1}
664    ///
665    /// All computations use Variable operations for proper gradient flow.
666    pub fn forward_step(&self, input: &Variable, hidden: &Variable) -> Variable {
667        let _batch_size = input.shape()[0];
668        let hidden_size = self.hidden_size;
669
670        // Get weight matrices
671        let weight_ih = self.weight_ih.variable();
672        let weight_hh = self.weight_hh.variable();
673        let bias_ih = self.bias_ih.variable();
674        let bias_hh = self.bias_hh.variable();
675
676        // Compute input transformation: x @ W_ih^T + b_ih
677        // Shape: [batch, 3*hidden_size]
678        let weight_ih_t = weight_ih.transpose(0, 1);
679        let ih = input.matmul(&weight_ih_t).add_var(&bias_ih);
680
681        // Compute hidden transformation: h @ W_hh^T + b_hh
682        // Shape: [batch, 3*hidden_size]
683        let weight_hh_t = weight_hh.transpose(0, 1);
684        let hh = hidden.matmul(&weight_hh_t).add_var(&bias_hh);
685
686        // Use narrow to split into gates (preserves gradient flow)
687        // Each gate slice: [batch, hidden_size]
688        let ih_r = ih.narrow(1, 0, hidden_size);
689        let ih_z = ih.narrow(1, hidden_size, hidden_size);
690        let ih_n = ih.narrow(1, 2 * hidden_size, hidden_size);
691
692        let hh_r = hh.narrow(1, 0, hidden_size);
693        let hh_z = hh.narrow(1, hidden_size, hidden_size);
694        let hh_n = hh.narrow(1, 2 * hidden_size, hidden_size);
695
696        // Compute gates using Variable operations for gradient flow
697        // r = sigmoid(ih_r + hh_r)
698        let r = ih_r.add_var(&hh_r).sigmoid();
699
700        // z = sigmoid(ih_z + hh_z)
701        let z = ih_z.add_var(&hh_z).sigmoid();
702
703        // n = tanh(ih_n + r * hh_n)
704        let n = ih_n.add_var(&r.mul_var(&hh_n)).tanh();
705
706        // h_new = (1 - z) * n + z * h_prev
707        // Rewritten as: n + z * (h_prev - n)  to avoid allocating a ones tensor
708        let h_minus_n = hidden.sub_var(&n);
709        n.add_var(&z.mul_var(&h_minus_n))
710    }
711}
712
713impl Module for GRUCell {
714    fn forward(&self, input: &Variable) -> Variable {
715        let batch_size = input.shape()[0];
716
717        // Initialize hidden state to zeros
718        let hidden = Variable::new(
719            zeros(&[batch_size, self.hidden_size]),
720            input.requires_grad(),
721        );
722
723        self.forward_step(input, &hidden)
724    }
725
726    fn parameters(&self) -> Vec<Parameter> {
727        vec![
728            self.weight_ih.clone(),
729            self.weight_hh.clone(),
730            self.bias_ih.clone(),
731            self.bias_hh.clone(),
732        ]
733    }
734
735    fn named_parameters(&self) -> HashMap<String, Parameter> {
736        let mut params = HashMap::new();
737        params.insert("weight_ih".to_string(), self.weight_ih.clone());
738        params.insert("weight_hh".to_string(), self.weight_hh.clone());
739        params.insert("bias_ih".to_string(), self.bias_ih.clone());
740        params.insert("bias_hh".to_string(), self.bias_hh.clone());
741        params
742    }
743
744    fn name(&self) -> &'static str {
745        "GRUCell"
746    }
747}
748
/// Multi-layer GRU.
///
/// Stacks `num_layers` `GRUCell`s; layer 0 consumes the input features,
/// each deeper layer consumes the hidden state of the layer below.
pub struct GRU {
    /// GRU cells, one per layer.
    cells: Vec<GRUCell>,
    /// Hidden state size, shared by all layers.
    hidden_size: usize,
    /// Number of stacked layers.
    num_layers: usize,
    /// If true, input is (batch, seq, features), else (seq, batch, features).
    batch_first: bool,
}
760
761impl GRU {
762    /// Creates a new multi-layer GRU.
763    pub fn new(input_size: usize, hidden_size: usize, num_layers: usize) -> Self {
764        let mut cells = Vec::with_capacity(num_layers);
765        cells.push(GRUCell::new(input_size, hidden_size));
766        for _ in 1..num_layers {
767            cells.push(GRUCell::new(hidden_size, hidden_size));
768        }
769        Self {
770            cells,
771            hidden_size,
772            num_layers,
773            batch_first: true,
774        }
775    }
776
777    /// Returns the hidden state size.
778    pub fn hidden_size(&self) -> usize {
779        self.hidden_size
780    }
781
782    /// Returns the number of layers.
783    pub fn num_layers(&self) -> usize {
784        self.num_layers
785    }
786}
787
788impl Module for GRU {
789    fn forward(&self, input: &Variable) -> Variable {
790        let shape = input.shape();
791        let (batch_size, seq_len, input_features) = if self.batch_first {
792            (shape[0], shape[1], shape[2])
793        } else {
794            (shape[1], shape[0], shape[2])
795        };
796
797        // Initialize hidden states for all layers as Variables (with gradients)
798        let mut hidden_states: Vec<Variable> = (0..self.num_layers)
799            .map(|_| {
800                Variable::new(
801                    zeros(&[batch_size, self.hidden_size]),
802                    input.requires_grad(),
803                )
804            })
805            .collect();
806
807        // Pre-compute input-to-hidden projection for layer 0 across ALL timesteps
808        // One big matmul instead of seq_len small ones
809        let cell0 = &self.cells[0];
810        let input_2d = input.reshape(&[batch_size * seq_len, input_features]);
811        let w_ih_t = cell0.weight_ih.variable().transpose(0, 1);
812        let ih_all = input_2d.matmul(&w_ih_t).add_var(&cell0.bias_ih.variable());
813        let ih_all_3d = ih_all.reshape(&[batch_size, seq_len, 3 * self.hidden_size]);
814
815        // Hoist weight transpose + bias out of the per-timestep loop
816        let w_hh_t_0 = cell0.weight_hh.variable().transpose(0, 1);
817        let bias_hh_0 = cell0.bias_hh.variable();
818
819        let mut output_vars: Vec<Variable> = Vec::with_capacity(seq_len);
820
821        // Check if we're on GPU for fused gate kernel path
822        #[cfg(feature = "cuda")]
823        let on_gpu = input.data().device().is_gpu();
824        #[cfg(not(feature = "cuda"))]
825        let on_gpu = false;
826
827        for t in 0..seq_len {
828            // Layer 0: use pre-computed ih projection + hoisted weight transpose
829            let ih_t = ih_all_3d.select(1, t);
830            let hidden = &hidden_states[0];
831            let hs = self.hidden_size;
832
833            let hh = hidden.matmul(&w_hh_t_0).add_var(&bias_hh_0);
834
835            if on_gpu {
836                // GPU path: fused GRU gate kernel (1 launch vs ~12 separate ops)
837                // ih_t [batch, 3*hidden], hh [batch, 3*hidden], hidden [batch, hidden] → h_new [batch, hidden]
838                #[cfg(feature = "cuda")]
839                {
840                    let ih_data = ih_t.data();
841                    let hh_data = hh.data();
842                    let h_data = hidden.data();
843
844                    if let Some(h_tensor) = ih_data.gru_gates_fused(&hh_data, &h_data, hs) {
845                        // Save forward state for backward
846                        let saved_ih = ih_data.clone();
847                        let saved_hh = hh_data.clone();
848                        let saved_h_prev = h_data.clone();
849
850                        // Create proper backward that calls GRU backward kernel
851                        let backward_fn = axonml_autograd::GruGatesBackward::new(
852                            ih_t.grad_fn().cloned(),
853                            hh.grad_fn().cloned(),
854                            hidden.grad_fn().cloned(),
855                            saved_ih,
856                            saved_hh,
857                            saved_h_prev,
858                            hs,
859                        );
860                        let grad_fn = axonml_autograd::GradFn::new(backward_fn);
861
862                        let h_new =
863                            Variable::from_operation(h_tensor, grad_fn, input.requires_grad());
864                        hidden_states[0] = h_new;
865                    }
866                }
867            } else {
868                // CPU path: individual ops (each autograd-tracked)
869                let ih_r = ih_t.narrow(1, 0, hs);
870                let ih_z = ih_t.narrow(1, hs, hs);
871                let ih_n = ih_t.narrow(1, 2 * hs, hs);
872                let hh_r = hh.narrow(1, 0, hs);
873                let hh_z = hh.narrow(1, hs, hs);
874                let hh_n = hh.narrow(1, 2 * hs, hs);
875
876                let r = ih_r.add_var(&hh_r).sigmoid();
877                let z = ih_z.add_var(&hh_z).sigmoid();
878                let n = ih_n.add_var(&r.mul_var(&hh_n)).tanh();
879                let h_minus_n = hidden.sub_var(&n);
880                let h_new = n.add_var(&z.mul_var(&h_minus_n));
881                hidden_states[0] = h_new;
882            }
883
884            // Subsequent layers use the regular cell forward_step
885            let mut layer_output = hidden_states[0].clone();
886            for l in 1..self.num_layers {
887                let new_hidden = self.cells[l].forward_step(&layer_output, &hidden_states[l]);
888                hidden_states[l] = new_hidden.clone();
889                layer_output = new_hidden;
890            }
891
892            output_vars.push(layer_output);
893        }
894
895        // Stack outputs along the time dimension
896        self.stack_outputs(&output_vars, batch_size, seq_len)
897    }
898
899    fn parameters(&self) -> Vec<Parameter> {
900        self.cells.iter().flat_map(|c| c.parameters()).collect()
901    }
902
903    fn named_parameters(&self) -> HashMap<String, Parameter> {
904        let mut params = HashMap::new();
905        if self.cells.len() == 1 {
906            for (n, p) in self.cells[0].named_parameters() {
907                params.insert(n, p);
908            }
909        } else {
910            for (i, cell) in self.cells.iter().enumerate() {
911                for (n, p) in cell.named_parameters() {
912                    params.insert(format!("cells.{i}.{n}"), p);
913                }
914            }
915        }
916        params
917    }
918
    /// Human-readable module name (used for debug/summary display).
    fn name(&self) -> &'static str {
        "GRU"
    }
922}
923
924impl GRU {
925    /// Forward pass that returns the mean of all hidden states.
926    /// This is equivalent to processing then mean pooling, but with proper gradient flow.
927    pub fn forward_mean(&self, input: &Variable) -> Variable {
928        let shape = input.shape();
929        let (batch_size, seq_len, input_features) = if self.batch_first {
930            (shape[0], shape[1], shape[2])
931        } else {
932            (shape[1], shape[0], shape[2])
933        };
934
935        let mut hidden_states: Vec<Variable> = (0..self.num_layers)
936            .map(|_| {
937                Variable::new(
938                    zeros(&[batch_size, self.hidden_size]),
939                    input.requires_grad(),
940                )
941            })
942            .collect();
943
944        // Pre-compute input-to-hidden projection for layer 0 across ALL timesteps
945        let cell0 = &self.cells[0];
946        let input_2d = input.reshape(&[batch_size * seq_len, input_features]);
947        let w_ih_t = cell0.weight_ih.variable().transpose(0, 1);
948        let ih_all = input_2d.matmul(&w_ih_t).add_var(&cell0.bias_ih.variable());
949        let ih_all_3d = ih_all.reshape(&[batch_size, seq_len, 3 * self.hidden_size]);
950
951        // Hoist weight transpose + bias out of per-timestep loop
952        let w_hh_t_0 = cell0.weight_hh.variable().transpose(0, 1);
953        let bias_hh_0 = cell0.bias_hh.variable();
954
955        let mut output_sum: Option<Variable> = None;
956        let hs = self.hidden_size;
957
958        for t in 0..seq_len {
959            // Layer 0: use pre-computed ih projection + hoisted weight transpose
960            let ih_t = ih_all_3d.select(1, t);
961            let hidden = &hidden_states[0];
962            let hh = hidden.matmul(&w_hh_t_0).add_var(&bias_hh_0);
963
964            let ih_r = ih_t.narrow(1, 0, hs);
965            let ih_z = ih_t.narrow(1, hs, hs);
966            let ih_n = ih_t.narrow(1, 2 * hs, hs);
967            let hh_r = hh.narrow(1, 0, hs);
968            let hh_z = hh.narrow(1, hs, hs);
969            let hh_n = hh.narrow(1, 2 * hs, hs);
970
971            let r = ih_r.add_var(&hh_r).sigmoid();
972            let z = ih_z.add_var(&hh_z).sigmoid();
973            let n = ih_n.add_var(&r.mul_var(&hh_n)).tanh();
974            let h_minus_n = hidden.sub_var(&n);
975            let h_new = n.add_var(&z.mul_var(&h_minus_n));
976            hidden_states[0] = h_new.clone();
977
978            // Subsequent layers
979            let mut layer_output = h_new;
980            for l in 1..self.num_layers {
981                let new_hidden = self.cells[l].forward_step(&layer_output, &hidden_states[l]);
982                hidden_states[l] = new_hidden.clone();
983                layer_output = new_hidden;
984            }
985
986            output_sum = Some(match output_sum {
987                None => layer_output,
988                Some(acc) => acc.add_var(&layer_output),
989            });
990        }
991
992        match output_sum {
993            Some(sum) => sum.mul_scalar(1.0 / seq_len as f32),
994            None => Variable::new(zeros(&[batch_size, self.hidden_size]), false),
995        }
996    }
997
998    /// Forward pass that returns the last hidden state.
999    /// Good for sequence classification with proper gradient flow.
1000    pub fn forward_last(&self, input: &Variable) -> Variable {
1001        let shape = input.shape();
1002        let (batch_size, seq_len, input_features) = if self.batch_first {
1003            (shape[0], shape[1], shape[2])
1004        } else {
1005            (shape[1], shape[0], shape[2])
1006        };
1007
1008        let mut hidden_states: Vec<Variable> = (0..self.num_layers)
1009            .map(|_| {
1010                Variable::new(
1011                    zeros(&[batch_size, self.hidden_size]),
1012                    input.requires_grad(),
1013                )
1014            })
1015            .collect();
1016
1017        // Pre-compute input-to-hidden projection for layer 0 across ALL timesteps
1018        let cell0 = &self.cells[0];
1019        let input_2d = input.reshape(&[batch_size * seq_len, input_features]);
1020        let w_ih_t = cell0.weight_ih.variable().transpose(0, 1);
1021        let ih_all = input_2d.matmul(&w_ih_t).add_var(&cell0.bias_ih.variable());
1022        let ih_all_3d = ih_all.reshape(&[batch_size, seq_len, 3 * self.hidden_size]);
1023
1024        // Hoist weight transpose + bias out of per-timestep loop
1025        let w_hh_t_0 = cell0.weight_hh.variable().transpose(0, 1);
1026        let bias_hh_0 = cell0.bias_hh.variable();
1027        let hs = self.hidden_size;
1028
1029        for t in 0..seq_len {
1030            // Layer 0: use pre-computed ih projection + hoisted weight transpose
1031            let ih_t = ih_all_3d.select(1, t);
1032            let hidden = &hidden_states[0];
1033            let hh = hidden.matmul(&w_hh_t_0).add_var(&bias_hh_0);
1034
1035            let ih_r = ih_t.narrow(1, 0, hs);
1036            let ih_z = ih_t.narrow(1, hs, hs);
1037            let ih_n = ih_t.narrow(1, 2 * hs, hs);
1038            let hh_r = hh.narrow(1, 0, hs);
1039            let hh_z = hh.narrow(1, hs, hs);
1040            let hh_n = hh.narrow(1, 2 * hs, hs);
1041
1042            let r = ih_r.add_var(&hh_r).sigmoid();
1043            let z = ih_z.add_var(&hh_z).sigmoid();
1044            let n = ih_n.add_var(&r.mul_var(&hh_n)).tanh();
1045            let h_minus_n = hidden.sub_var(&n);
1046            let h_new = n.add_var(&z.mul_var(&h_minus_n));
1047            hidden_states[0] = h_new.clone();
1048
1049            // Subsequent layers
1050            let mut layer_input = h_new;
1051
1052            for (layer_idx, cell) in self.cells.iter().enumerate().skip(1) {
1053                let new_hidden = cell.forward_step(&layer_input, &hidden_states[layer_idx]);
1054                hidden_states[layer_idx] = new_hidden.clone();
1055                layer_input = new_hidden;
1056            }
1057        }
1058
1059        // Return last hidden state from last layer
1060        hidden_states
1061            .pop()
1062            .unwrap_or_else(|| Variable::new(zeros(&[batch_size, self.hidden_size]), false))
1063    }
1064
1065    /// Stack output Variables into a single [batch, seq, hidden] tensor.
1066    /// Note: This creates a new tensor without gradient connections to individual timesteps.
1067    /// For gradient flow, use forward_mean() or forward_last() instead.
1068    fn stack_outputs(&self, outputs: &[Variable], batch_size: usize, _seq_len: usize) -> Variable {
1069        if outputs.is_empty() {
1070            return Variable::new(zeros(&[batch_size, 0, self.hidden_size]), false);
1071        }
1072
1073        // Unsqueeze each (batch, hidden) → (batch, 1, hidden), then cat along dim=1
1074        let unsqueezed: Vec<Variable> = outputs.iter().map(|o| o.unsqueeze(1)).collect();
1075        let refs: Vec<&Variable> = unsqueezed.iter().collect();
1076        Variable::cat(&refs, 1)
1077    }
1078}
1079
1080// =============================================================================
1081// Tests
1082// =============================================================================
1083
#[cfg(test)]
mod tests {
    use super::*;
    use axonml_tensor::Tensor;

    /// A single RNNCell step maps [batch, in] + [batch, hidden] → [batch, hidden].
    #[test]
    fn test_rnn_cell() {
        let cell = RNNCell::new(10, 20);
        let x = Variable::new(Tensor::from_vec(vec![1.0; 20], &[2, 10]).unwrap(), false);
        let h0 = Variable::new(Tensor::from_vec(vec![0.0; 40], &[2, 20]).unwrap(), false);
        let h1 = cell.forward_step(&x, &h0);
        assert_eq!(h1.shape(), vec![2, 20]);
    }

    /// Stacked RNN preserves batch/seq dims and maps features → hidden.
    #[test]
    fn test_rnn() {
        let rnn = RNN::new(10, 20, 2);
        let x = Variable::new(
            Tensor::from_vec(vec![1.0; 100], &[2, 5, 10]).unwrap(),
            false,
        );
        assert_eq!(rnn.forward(&x).shape(), vec![2, 5, 20]);
    }

    /// Single-layer LSTM produces [batch, seq, hidden] output.
    #[test]
    fn test_lstm() {
        let lstm = LSTM::new(10, 20, 1);
        let x = Variable::new(
            Tensor::from_vec(vec![1.0; 100], &[2, 5, 10]).unwrap(),
            false,
        );
        assert_eq!(lstm.forward(&x).shape(), vec![2, 5, 20]);
    }

    /// End-to-end backward pass: at least one GRU parameter must receive a
    /// non-zero gradient after summing the forward output.
    #[test]
    fn test_gru_gradients_reach_parameters() {
        let gru = GRU::new(4, 8, 1);
        let input = Variable::new(
            Tensor::from_vec(vec![0.5f32; 2 * 3 * 4], &[2, 3, 4]).unwrap(),
            true,
        );

        let output = gru.forward(&input);
        println!(
            "Output shape: {:?}, requires_grad: {}",
            output.shape(),
            output.requires_grad()
        );

        let loss = output.sum();
        println!(
            "Loss: {:?}, requires_grad: {}",
            loss.data().to_vec(),
            loss.requires_grad()
        );
        loss.backward();

        // Check input gradient
        println!(
            "Input grad: {:?}",
            input
                .grad()
                .map(|g| g.to_vec().iter().map(|x| x.abs()).sum::<f32>())
        );

        let params = gru.parameters();
        println!("Number of parameters: {}", params.len());

        let mut saw_nonzero_grad = false;
        for (i, p) in params.iter().enumerate() {
            if let Some(g) = p.grad() {
                let sum_abs: f32 = g.to_vec().iter().map(|x| x.abs()).sum();
                println!(
                    "Param {} shape {:?} requires_grad={}: grad sum_abs={:.6}",
                    i,
                    p.shape(),
                    p.requires_grad(),
                    sum_abs
                );
                saw_nonzero_grad |= sum_abs > 0.0;
            } else {
                println!(
                    "Param {} shape {:?} requires_grad={}: NO GRADIENT",
                    i,
                    p.shape(),
                    p.requires_grad()
                );
            }
        }
        assert!(
            saw_nonzero_grad,
            "At least one GRU parameter should have non-zero gradients"
        );
    }
}