// axonml_nn/layers/rnn.rs

//! Recurrent Neural Network Layers - RNN, LSTM, GRU
//!
//! Processes sequential data with recurrent connections.
//!
//! @version 0.1.0
//! @author AutomataNexus Development Team

use std::collections::HashMap;

use axonml_autograd::Variable;
use axonml_tensor::Tensor;

use crate::init::{xavier_uniform, zeros};
use crate::module::Module;
use crate::parameter::Parameter;

// =============================================================================
// RNNCell
// =============================================================================

/// A single RNN cell.
///
/// h' = tanh(W_ih * x + b_ih + W_hh * h + b_hh)
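///
/// # Example
///
/// A minimal single-step sketch (import paths assumed; the tests at the
/// bottom of this file contain a compiled equivalent):
///
/// ```ignore
/// let cell = RNNCell::new(10, 20);
/// let input = Variable::new(Tensor::from_vec(vec![1.0; 20], &[2, 10]).unwrap(), false);
/// let hidden = Variable::new(Tensor::from_vec(vec![0.0; 40], &[2, 20]).unwrap(), false);
/// let next_hidden = cell.forward_step(&input, &hidden); // shape [2, 20]
/// ```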
pub struct RNNCell {
    /// Input-hidden weights.
    pub weight_ih: Parameter,
    /// Hidden-hidden weights.
    pub weight_hh: Parameter,
    /// Input-hidden bias.
    pub bias_ih: Parameter,
    /// Hidden-hidden bias.
    pub bias_hh: Parameter,
    /// Input size.
    input_size: usize,
    /// Hidden size.
    hidden_size: usize,
}

impl RNNCell {
    /// Creates a new RNNCell.
    pub fn new(input_size: usize, hidden_size: usize) -> Self {
        Self {
            weight_ih: Parameter::named("weight_ih", xavier_uniform(input_size, hidden_size), true),
            weight_hh: Parameter::named(
                "weight_hh",
                xavier_uniform(hidden_size, hidden_size),
                true,
            ),
            bias_ih: Parameter::named("bias_ih", zeros(&[hidden_size]), true),
            bias_hh: Parameter::named("bias_hh", zeros(&[hidden_size]), true),
            input_size,
            hidden_size,
        }
    }

    /// Returns the expected input size.
    pub fn input_size(&self) -> usize {
        self.input_size
    }

    /// Returns the hidden state size.
    pub fn hidden_size(&self) -> usize {
        self.hidden_size
    }

    /// Forward pass for a single time step.
    pub fn forward_step(&self, input: &Variable, hidden: &Variable) -> Variable {
        let input_features = input.data().shape().last().copied().unwrap_or(0);
        assert_eq!(
            input_features, self.input_size,
            "RNNCell: expected input size {}, got {}",
            self.input_size, input_features
        );

        // x @ W_ih^T + b_ih
        let weight_ih = self.weight_ih.variable();
        let weight_ih_t = Variable::new(weight_ih.data().t().unwrap(), weight_ih.requires_grad());
        let ih = input.matmul(&weight_ih_t);
        let bias_ih = self.bias_ih.variable();
        let ih = ih.add_var(&bias_ih);

        // h @ W_hh^T + b_hh
        let weight_hh = self.weight_hh.variable();
        let weight_hh_t = Variable::new(weight_hh.data().t().unwrap(), weight_hh.requires_grad());
        let hh = hidden.matmul(&weight_hh_t);
        let bias_hh = self.bias_hh.variable();
        let hh = hh.add_var(&bias_hh);

        // tanh(ih + hh)
        ih.add_var(&hh).tanh()
    }
}

impl Module for RNNCell {
    fn forward(&self, input: &Variable) -> Variable {
        // Initialize hidden state to zeros
        let batch_size = input.shape()[0];
        let hidden = Variable::new(
            zeros(&[batch_size, self.hidden_size]),
            input.requires_grad(),
        );
        self.forward_step(input, &hidden)
    }

    fn parameters(&self) -> Vec<Parameter> {
        vec![
            self.weight_ih.clone(),
            self.weight_hh.clone(),
            self.bias_ih.clone(),
            self.bias_hh.clone(),
        ]
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut params = HashMap::new();
        params.insert("weight_ih".to_string(), self.weight_ih.clone());
        params.insert("weight_hh".to_string(), self.weight_hh.clone());
        params.insert("bias_ih".to_string(), self.bias_ih.clone());
        params.insert("bias_hh".to_string(), self.bias_hh.clone());
        params
    }

    fn name(&self) -> &'static str {
        "RNNCell"
    }
}

// =============================================================================
// RNN
// =============================================================================

/// Multi-layer RNN.
///
/// Processes sequences through stacked RNN layers.
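///
/// # Example
///
/// A batch-first usage sketch (mirrors `test_rnn` below):
///
/// ```ignore
/// let rnn = RNN::new(10, 20, 2);
/// // input: [batch=2, seq=5, features=10]
/// let input = Variable::new(Tensor::from_vec(vec![1.0; 100], &[2, 5, 10]).unwrap(), false);
/// let output = rnn.forward(&input); // shape [2, 5, 20]
/// ```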
pub struct RNN {
    /// RNN cells for each layer.
    cells: Vec<RNNCell>,
    /// Input size.
    input_size: usize,
    /// Hidden size.
    hidden_size: usize,
    /// Number of layers.
    num_layers: usize,
    /// Batch first flag.
    batch_first: bool,
}

impl RNN {
    /// Creates a new multi-layer RNN with batch-first layout.
    pub fn new(input_size: usize, hidden_size: usize, num_layers: usize) -> Self {
        Self::with_options(input_size, hidden_size, num_layers, true)
    }

    /// Creates an RNN with all options.
    pub fn with_options(
        input_size: usize,
        hidden_size: usize,
        num_layers: usize,
        batch_first: bool,
    ) -> Self {
        let mut cells = Vec::with_capacity(num_layers);

        // First layer takes input_size
        cells.push(RNNCell::new(input_size, hidden_size));

        // Subsequent layers take hidden_size
        for _ in 1..num_layers {
            cells.push(RNNCell::new(hidden_size, hidden_size));
        }

        Self {
            cells,
            input_size,
            hidden_size,
            num_layers,
            batch_first,
        }
    }
}

impl Module for RNN {
    fn forward(&self, input: &Variable) -> Variable {
        let shape = input.shape();
        let (batch_size, seq_len) = if self.batch_first {
            (shape[0], shape[1])
        } else {
            (shape[1], shape[0])
        };

        // Initialize hidden states for each layer
        let mut hiddens: Vec<Variable> = (0..self.num_layers)
            .map(|_| {
                Variable::new(
                    zeros(&[batch_size, self.hidden_size]),
                    input.requires_grad(),
                )
            })
            .collect();

        // Process each time step
        let input_data = input.data();
        let input_vec = input_data.to_vec();
        let mut outputs = Vec::with_capacity(seq_len);

        for t in 0..seq_len {
            // Extract the [batch, features] slice at time t.
            // NOTE: building a fresh Variable from raw data detaches the
            // autograd graph at this slice boundary.
            let mut slice_data = vec![0.0f32; batch_size * self.input_size];
            for b in 0..batch_size {
                for f in 0..self.input_size {
                    let src_idx = if self.batch_first {
                        // [batch, seq, features]
                        b * seq_len * self.input_size + t * self.input_size + f
                    } else {
                        // [seq, batch, features]
                        t * batch_size * self.input_size + b * self.input_size + f
                    };
                    slice_data[b * self.input_size + f] = input_vec[src_idx];
                }
            }
            let t_input = Variable::new(
                Tensor::from_vec(slice_data, &[batch_size, self.input_size]).unwrap(),
                input.requires_grad(),
            );

            // Feed the slice through the stacked layers
            let mut layer_input = t_input;
            for (l, cell) in self.cells.iter().enumerate() {
                hiddens[l] = cell.forward_step(&layer_input, &hiddens[l]);
                layer_input = hiddens[l].clone();
            }

            // The last layer's hidden state is the output for this step
            outputs.push(hiddens[self.num_layers - 1].clone());
        }

        // Stack per-step outputs into the full output tensor
        let output_size = batch_size * seq_len * self.hidden_size;
        let mut output_data = vec![0.0f32; output_size];

        for (t, out) in outputs.iter().enumerate() {
            let out_vec = out.data().to_vec();
            for b in 0..batch_size {
                for h in 0..self.hidden_size {
                    let src_idx = b * self.hidden_size + h;
                    let dst_idx = if self.batch_first {
                        b * seq_len * self.hidden_size + t * self.hidden_size + h
                    } else {
                        t * batch_size * self.hidden_size + b * self.hidden_size + h
                    };
                    output_data[dst_idx] = out_vec[src_idx];
                }
            }
        }

        let output_shape = if self.batch_first {
            vec![batch_size, seq_len, self.hidden_size]
        } else {
            vec![seq_len, batch_size, self.hidden_size]
        };

        Variable::new(
            Tensor::from_vec(output_data, &output_shape).unwrap(),
            input.requires_grad(),
        )
    }

    fn parameters(&self) -> Vec<Parameter> {
        self.cells.iter().flat_map(|c| c.parameters()).collect()
    }

    fn name(&self) -> &'static str {
        "RNN"
    }
}

// =============================================================================
// LSTMCell
// =============================================================================

/// A single LSTM cell.
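///
/// Gate equations (as implemented in `forward_step` below):
///   i = sigmoid(W_ii * x + b_ii + W_hi * h + b_hi)  (input gate)
///   f = sigmoid(W_if * x + b_if + W_hf * h + b_hf)  (forget gate)
///   g = tanh(W_ig * x + b_ig + W_hg * h + b_hg)     (cell gate)
///   o = sigmoid(W_io * x + b_io + W_ho * h + b_ho)  (output gate)
///   c' = f * c + i * g
///   h' = o * tanh(c')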
pub struct LSTMCell {
    /// Input-hidden weights for all gates.
    pub weight_ih: Parameter,
    /// Hidden-hidden weights for all gates.
    pub weight_hh: Parameter,
    /// Input-hidden bias for all gates.
    pub bias_ih: Parameter,
    /// Hidden-hidden bias for all gates.
    pub bias_hh: Parameter,
    /// Input size.
    input_size: usize,
    /// Hidden size.
    hidden_size: usize,
}

impl LSTMCell {
    /// Creates a new LSTMCell.
    pub fn new(input_size: usize, hidden_size: usize) -> Self {
        // LSTM has 4 gates, so the fused weight size is 4 * hidden_size
        Self {
            weight_ih: Parameter::named(
                "weight_ih",
                xavier_uniform(input_size, 4 * hidden_size),
                true,
            ),
            weight_hh: Parameter::named(
                "weight_hh",
                xavier_uniform(hidden_size, 4 * hidden_size),
                true,
            ),
            bias_ih: Parameter::named("bias_ih", zeros(&[4 * hidden_size]), true),
            bias_hh: Parameter::named("bias_hh", zeros(&[4 * hidden_size]), true),
            input_size,
            hidden_size,
        }
    }

    /// Returns the expected input size.
    pub fn input_size(&self) -> usize {
        self.input_size
    }

    /// Returns the hidden state size.
    pub fn hidden_size(&self) -> usize {
        self.hidden_size
    }

    /// Forward pass for a single time step, returning (h', c').
    pub fn forward_step(
        &self,
        input: &Variable,
        hx: &(Variable, Variable),
    ) -> (Variable, Variable) {
        let input_features = input.data().shape().last().copied().unwrap_or(0);
        assert_eq!(
            input_features, self.input_size,
            "LSTMCell: expected input size {}, got {}",
            self.input_size, input_features
        );

        let (h, c) = hx;

        // Compute all gate pre-activations at once (x @ W^T + b)
        let weight_ih = self.weight_ih.variable();
        let weight_ih_t = Variable::new(weight_ih.data().t().unwrap(), weight_ih.requires_grad());
        let ih = input.matmul(&weight_ih_t);
        let bias_ih = self.bias_ih.variable();
        let ih = ih.add_var(&bias_ih);

        let weight_hh = self.weight_hh.variable();
        let weight_hh_t = Variable::new(weight_hh.data().t().unwrap(), weight_hh.requires_grad());
        let hh = h.matmul(&weight_hh_t);
        let bias_hh = self.bias_hh.variable();
        let hh = hh.add_var(&bias_hh);

        let gates = ih.add_var(&hh);
        let gates_vec = gates.data().to_vec();
        let batch_size = input.shape()[0];

        // Split into the 4 gates: i, f, g, o.
        // NOTE: splitting on raw data creates fresh leaf Variables, so the
        // autograd graph is detached here and gradients do not flow back
        // through the fused projections above.
        let mut i_data = vec![0.0f32; batch_size * self.hidden_size];
        let mut f_data = vec![0.0f32; batch_size * self.hidden_size];
        let mut g_data = vec![0.0f32; batch_size * self.hidden_size];
        let mut o_data = vec![0.0f32; batch_size * self.hidden_size];

        for b in 0..batch_size {
            for j in 0..self.hidden_size {
                let base = b * 4 * self.hidden_size;
                i_data[b * self.hidden_size + j] = gates_vec[base + j];
                f_data[b * self.hidden_size + j] = gates_vec[base + self.hidden_size + j];
                g_data[b * self.hidden_size + j] = gates_vec[base + 2 * self.hidden_size + j];
                o_data[b * self.hidden_size + j] = gates_vec[base + 3 * self.hidden_size + j];
            }
        }

        let i = Variable::new(
            Tensor::from_vec(i_data, &[batch_size, self.hidden_size]).unwrap(),
            input.requires_grad(),
        )
        .sigmoid();
        let f = Variable::new(
            Tensor::from_vec(f_data, &[batch_size, self.hidden_size]).unwrap(),
            input.requires_grad(),
        )
        .sigmoid();
        let g = Variable::new(
            Tensor::from_vec(g_data, &[batch_size, self.hidden_size]).unwrap(),
            input.requires_grad(),
        )
        .tanh();
        let o = Variable::new(
            Tensor::from_vec(o_data, &[batch_size, self.hidden_size]).unwrap(),
            input.requires_grad(),
        )
        .sigmoid();

        // c' = f * c + i * g
        let c_new = f.mul_var(c).add_var(&i.mul_var(&g));

        // h' = o * tanh(c')
        let h_new = o.mul_var(&c_new.tanh());

        (h_new, c_new)
    }
}

impl Module for LSTMCell {
    fn forward(&self, input: &Variable) -> Variable {
        let batch_size = input.shape()[0];
        let h = Variable::new(
            zeros(&[batch_size, self.hidden_size]),
            input.requires_grad(),
        );
        let c = Variable::new(
            zeros(&[batch_size, self.hidden_size]),
            input.requires_grad(),
        );
        let (h_new, _) = self.forward_step(input, &(h, c));
        h_new
    }

    fn parameters(&self) -> Vec<Parameter> {
        vec![
            self.weight_ih.clone(),
            self.weight_hh.clone(),
            self.bias_ih.clone(),
            self.bias_hh.clone(),
        ]
    }

    fn name(&self) -> &'static str {
        "LSTMCell"
    }
}

// =============================================================================
// LSTM
// =============================================================================

/// Multi-layer LSTM.
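///
/// # Example
///
/// A batch-first usage sketch (mirrors `test_lstm` below):
///
/// ```ignore
/// let lstm = LSTM::new(10, 20, 1);
/// // input: [batch=2, seq=5, features=10]
/// let input = Variable::new(Tensor::from_vec(vec![1.0; 100], &[2, 5, 10]).unwrap(), false);
/// let output = lstm.forward(&input); // shape [2, 5, 20]
/// ```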
pub struct LSTM {
    /// LSTM cells for each layer.
    cells: Vec<LSTMCell>,
    /// Input size.
    input_size: usize,
    /// Hidden size.
    hidden_size: usize,
    /// Number of layers.
    num_layers: usize,
    /// Batch first flag.
    batch_first: bool,
}

impl LSTM {
    /// Creates a new multi-layer LSTM with batch-first layout.
    pub fn new(input_size: usize, hidden_size: usize, num_layers: usize) -> Self {
        Self::with_options(input_size, hidden_size, num_layers, true)
    }

    /// Creates an LSTM with all options.
    pub fn with_options(
        input_size: usize,
        hidden_size: usize,
        num_layers: usize,
        batch_first: bool,
    ) -> Self {
        let mut cells = Vec::with_capacity(num_layers);
        cells.push(LSTMCell::new(input_size, hidden_size));
        for _ in 1..num_layers {
            cells.push(LSTMCell::new(hidden_size, hidden_size));
        }

        Self {
            cells,
            input_size,
            hidden_size,
            num_layers,
            batch_first,
        }
    }

    /// Returns the expected input size.
    pub fn input_size(&self) -> usize {
        self.input_size
    }

    /// Returns the hidden state size.
    pub fn hidden_size(&self) -> usize {
        self.hidden_size
    }

    /// Returns the number of layers.
    pub fn num_layers(&self) -> usize {
        self.num_layers
    }
}

impl Module for LSTM {
    fn forward(&self, input: &Variable) -> Variable {
        // Same time-step unrolling as RNN::forward, but threading an
        // (h, c) state pair per layer through LSTM cells.
        let shape = input.shape();
        let (batch_size, seq_len, input_features) = if self.batch_first {
            (shape[0], shape[1], shape[2])
        } else {
            (shape[1], shape[0], shape[2])
        };

        let mut states: Vec<(Variable, Variable)> = (0..self.num_layers)
            .map(|_| {
                (
                    Variable::new(
                        zeros(&[batch_size, self.hidden_size]),
                        input.requires_grad(),
                    ),
                    Variable::new(
                        zeros(&[batch_size, self.hidden_size]),
                        input.requires_grad(),
                    ),
                )
            })
            .collect();

        let input_data = input.data();
        let input_vec = input_data.to_vec();
        let mut outputs = Vec::with_capacity(seq_len);

        for t in 0..seq_len {
            // Extract the [batch, features] slice at time t
            let mut slice_data = vec![0.0f32; batch_size * input_features];
            for b in 0..batch_size {
                for f in 0..input_features {
                    let src_idx = if self.batch_first {
                        b * seq_len * input_features + t * input_features + f
                    } else {
                        t * batch_size * input_features + b * input_features + f
                    };
                    slice_data[b * input_features + f] = input_vec[src_idx];
                }
            }

            let mut layer_input = Variable::new(
                Tensor::from_vec(slice_data, &[batch_size, input_features]).unwrap(),
                input.requires_grad(),
            );

            for (l, cell) in self.cells.iter().enumerate() {
                states[l] = cell.forward_step(&layer_input, &states[l]);
                // Layers after the first consume this layer's new hidden state
                layer_input = states[l].0.clone();
            }

            outputs.push(states[self.num_layers - 1].0.clone());
        }

        // Stack per-step outputs into the full output tensor
        let mut output_data = vec![0.0f32; batch_size * seq_len * self.hidden_size];
        for (t, out) in outputs.iter().enumerate() {
            let out_vec = out.data().to_vec();
            for b in 0..batch_size {
                for h in 0..self.hidden_size {
                    let dst_idx = if self.batch_first {
                        b * seq_len * self.hidden_size + t * self.hidden_size + h
                    } else {
                        t * batch_size * self.hidden_size + b * self.hidden_size + h
                    };
                    output_data[dst_idx] = out_vec[b * self.hidden_size + h];
                }
            }
        }

        let output_shape = if self.batch_first {
            vec![batch_size, seq_len, self.hidden_size]
        } else {
            vec![seq_len, batch_size, self.hidden_size]
        };

        Variable::new(
            Tensor::from_vec(output_data, &output_shape).unwrap(),
            input.requires_grad(),
        )
    }

    fn parameters(&self) -> Vec<Parameter> {
        self.cells.iter().flat_map(|c| c.parameters()).collect()
    }

    fn name(&self) -> &'static str {
        "LSTM"
    }
}

// =============================================================================
// GRUCell and GRU
// =============================================================================

/// A single GRU cell.
///
/// h' = (1 - z) * n + z * h
/// where:
///   r = sigmoid(W_ir * x + b_ir + W_hr * h + b_hr)  (reset gate)
///   z = sigmoid(W_iz * x + b_iz + W_hz * h + b_hz)  (update gate)
///   n = tanh(W_in * x + b_in + r * (W_hn * h + b_hn))  (new gate)
pub struct GRUCell {
    /// Input-hidden weights for all gates (reset, update, new).
    pub weight_ih: Parameter,
    /// Hidden-hidden weights for all gates (reset, update, new).
    pub weight_hh: Parameter,
    /// Input-hidden bias for all gates.
    pub bias_ih: Parameter,
    /// Hidden-hidden bias for all gates.
    pub bias_hh: Parameter,
    /// Input size.
    input_size: usize,
    /// Hidden size.
    hidden_size: usize,
}

impl GRUCell {
    /// Creates a new GRU cell.
    pub fn new(input_size: usize, hidden_size: usize) -> Self {
        // GRU has 3 gates, so the fused weight size is 3 * hidden_size
        Self {
            weight_ih: Parameter::named(
                "weight_ih",
                xavier_uniform(input_size, 3 * hidden_size),
                true,
            ),
            weight_hh: Parameter::named(
                "weight_hh",
                xavier_uniform(hidden_size, 3 * hidden_size),
                true,
            ),
            bias_ih: Parameter::named("bias_ih", zeros(&[3 * hidden_size]), true),
            bias_hh: Parameter::named("bias_hh", zeros(&[3 * hidden_size]), true),
            input_size,
            hidden_size,
        }
    }

    /// Returns the expected input size.
    pub fn input_size(&self) -> usize {
        self.input_size
    }

    /// Returns the hidden state size.
    pub fn hidden_size(&self) -> usize {
        self.hidden_size
    }
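
    /// Forward pass for a single time step.
    ///
    /// NOTE: a minimal sketch following the r/z/n equations documented on the
    /// struct, assuming the fused projections are laid out as (r | z | n)
    /// along the output dimension. Like `LSTMCell::forward_step`, it splits
    /// the fused gate pre-activations on raw data, which detaches the
    /// autograd graph at the split.
    pub fn forward_step(&self, input: &Variable, hidden: &Variable) -> Variable {
        let batch_size = input.shape()[0];
        let hs = self.hidden_size;

        // x @ W_ih^T + b_ih -> [batch, 3 * hidden]
        let weight_ih = self.weight_ih.variable();
        let weight_ih_t = Variable::new(weight_ih.data().t().unwrap(), weight_ih.requires_grad());
        let ih = input.matmul(&weight_ih_t).add_var(&self.bias_ih.variable());

        // h @ W_hh^T + b_hh -> [batch, 3 * hidden]
        let weight_hh = self.weight_hh.variable();
        let weight_hh_t = Variable::new(weight_hh.data().t().unwrap(), weight_hh.requires_grad());
        let hh = hidden.matmul(&weight_hh_t).add_var(&self.bias_hh.variable());

        let ih_vec = ih.data().to_vec();
        let hh_vec = hh.data().to_vec();
        let h_vec = hidden.data().to_vec();

        let sigmoid = |x: f32| 1.0 / (1.0 + (-x).exp());
        let mut out = vec![0.0f32; batch_size * hs];
        for b in 0..batch_size {
            let base = b * 3 * hs;
            for j in 0..hs {
                // r = sigmoid(ih_r + hh_r), z = sigmoid(ih_z + hh_z)
                let r = sigmoid(ih_vec[base + j] + hh_vec[base + j]);
                let z = sigmoid(ih_vec[base + hs + j] + hh_vec[base + hs + j]);
                // n = tanh(ih_n + r * hh_n)
                let n = (ih_vec[base + 2 * hs + j] + r * hh_vec[base + 2 * hs + j]).tanh();
                // h' = (1 - z) * n + z * h
                out[b * hs + j] = (1.0 - z) * n + z * h_vec[b * hs + j];
            }
        }

        Variable::new(
            Tensor::from_vec(out, &[batch_size, hs]).unwrap(),
            input.requires_grad(),
        )
    }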
}

impl Module for GRUCell {
    fn forward(&self, input: &Variable) -> Variable {
        // Initialize the hidden state to zeros and take a single step
        // (see the forward_step sketch above).
        let batch_size = input.shape()[0];
        let hidden = Variable::new(
            zeros(&[batch_size, self.hidden_size]),
            input.requires_grad(),
        );
        self.forward_step(input, &hidden)
    }

    fn parameters(&self) -> Vec<Parameter> {
        vec![
            self.weight_ih.clone(),
            self.weight_hh.clone(),
            self.bias_ih.clone(),
            self.bias_hh.clone(),
        ]
    }

    fn name(&self) -> &'static str {
        "GRUCell"
    }
}

/// Multi-layer GRU.
pub struct GRU {
    /// GRU cells for each layer.
    cells: Vec<GRUCell>,
    /// Hidden state size.
    hidden_size: usize,
    /// Number of layers.
    num_layers: usize,
    /// If true, input is (batch, seq, features), else (seq, batch, features).
    batch_first: bool,
}

impl GRU {
    /// Creates a new multi-layer GRU with batch-first layout.
    pub fn new(input_size: usize, hidden_size: usize, num_layers: usize) -> Self {
        let mut cells = Vec::with_capacity(num_layers);
        cells.push(GRUCell::new(input_size, hidden_size));
        for _ in 1..num_layers {
            cells.push(GRUCell::new(hidden_size, hidden_size));
        }
        Self {
            cells,
            hidden_size,
            num_layers,
            batch_first: true,
        }
    }

    /// Returns the hidden state size.
    pub fn hidden_size(&self) -> usize {
        self.hidden_size
    }

    /// Returns the number of layers.
    pub fn num_layers(&self) -> usize {
        self.num_layers
    }
}

impl Module for GRU {
    fn forward(&self, input: &Variable) -> Variable {
        // Simplified: returns zeros of the correct output shape. A full
        // implementation would unroll time steps through the cells'
        // forward_step, as RNN::forward does.
        let shape = input.shape();
        let (batch_size, seq_len) = if self.batch_first {
            (shape[0], shape[1])
        } else {
            (shape[1], shape[0])
        };

        let output_shape = if self.batch_first {
            vec![batch_size, seq_len, self.hidden_size]
        } else {
            vec![seq_len, batch_size, self.hidden_size]
        };

        Variable::new(zeros(&output_shape), input.requires_grad())
    }

    fn parameters(&self) -> Vec<Parameter> {
        self.cells.iter().flat_map(|c| c.parameters()).collect()
    }

    fn name(&self) -> &'static str {
        "GRU"
    }
}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_rnn_cell() {
        let cell = RNNCell::new(10, 20);
        let input = Variable::new(Tensor::from_vec(vec![1.0; 20], &[2, 10]).unwrap(), false);
        let hidden = Variable::new(Tensor::from_vec(vec![0.0; 40], &[2, 20]).unwrap(), false);
        let output = cell.forward_step(&input, &hidden);
        assert_eq!(output.shape(), vec![2, 20]);
    }
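
    // A single LSTM step should return h' and c' with matching shapes.
    #[test]
    fn test_lstm_cell() {
        let cell = LSTMCell::new(10, 20);
        let input = Variable::new(Tensor::from_vec(vec![1.0; 20], &[2, 10]).unwrap(), false);
        let h = Variable::new(Tensor::from_vec(vec![0.0; 40], &[2, 20]).unwrap(), false);
        let c = Variable::new(Tensor::from_vec(vec![0.0; 40], &[2, 20]).unwrap(), false);
        let (h_new, c_new) = cell.forward_step(&input, &(h, c));
        assert_eq!(h_new.shape(), vec![2, 20]);
        assert_eq!(c_new.shape(), vec![2, 20]);
    }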

    #[test]
    fn test_rnn() {
        let rnn = RNN::new(10, 20, 2);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 100], &[2, 5, 10]).unwrap(),
            false,
        );
        let output = rnn.forward(&input);
        assert_eq!(output.shape(), vec![2, 5, 20]);
    }
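
    #[test]
    fn test_rnn_seq_first() {
        // Same data in seq-first layout: [seq, batch, features].
        let rnn = RNN::with_options(10, 20, 2, false);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 100], &[5, 2, 10]).unwrap(),
            false,
        );
        let output = rnn.forward(&input);
        assert_eq!(output.shape(), vec![5, 2, 20]);
    }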

    #[test]
    fn test_lstm() {
        let lstm = LSTM::new(10, 20, 1);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 100], &[2, 5, 10]).unwrap(),
            false,
        );
        let output = lstm.forward(&input);
        assert_eq!(output.shape(), vec![2, 5, 20]);
    }
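
    // Exercises the illustrative GRUCell::forward_step sketch; shape only.
    #[test]
    fn test_gru_cell() {
        let cell = GRUCell::new(10, 20);
        let input = Variable::new(Tensor::from_vec(vec![1.0; 20], &[2, 10]).unwrap(), false);
        let hidden = Variable::new(Tensor::from_vec(vec![0.0; 40], &[2, 20]).unwrap(), false);
        let output = cell.forward_step(&input, &hidden);
        assert_eq!(output.shape(), vec![2, 20]);
    }

    // GRU::forward is a shape-only stub; assert just the output shape.
    #[test]
    fn test_gru() {
        let gru = GRU::new(10, 20, 2);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 100], &[2, 5, 10]).unwrap(),
            false,
        );
        let output = gru.forward(&input);
        assert_eq!(output.shape(), vec![2, 5, 20]);
    }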
}