1use std::collections::HashMap;
9
10use axonml_autograd::Variable;
11use axonml_tensor::Tensor;
12
13use crate::init::{xavier_uniform, zeros};
14use crate::module::Module;
15use crate::parameter::Parameter;
16
/// A single Elman RNN cell: `h' = tanh(x W_ih^T + b_ih + h W_hh^T + b_hh)`.
///
/// Processes one time step; use [`RNN`] to unroll over a sequence.
pub struct RNNCell {
    /// Input-to-hidden weights (learnable, Xavier-initialized in `new`).
    pub weight_ih: Parameter,
    /// Hidden-to-hidden weights (learnable, Xavier-initialized in `new`).
    pub weight_hh: Parameter,
    /// Input-to-hidden bias (learnable, zero-initialized in `new`).
    pub bias_ih: Parameter,
    /// Hidden-to-hidden bias (learnable, zero-initialized in `new`).
    pub bias_hh: Parameter,
    // Expected size of the input's last dimension; checked in `forward_step`.
    input_size: usize,
    // Width of the hidden state this cell produces.
    hidden_size: usize,
}
38
39impl RNNCell {
40 pub fn new(input_size: usize, hidden_size: usize) -> Self {
42 Self {
43 weight_ih: Parameter::named("weight_ih", xavier_uniform(input_size, hidden_size), true),
44 weight_hh: Parameter::named(
45 "weight_hh",
46 xavier_uniform(hidden_size, hidden_size),
47 true,
48 ),
49 bias_ih: Parameter::named("bias_ih", zeros(&[hidden_size]), true),
50 bias_hh: Parameter::named("bias_hh", zeros(&[hidden_size]), true),
51 input_size,
52 hidden_size,
53 }
54 }
55
56 pub fn input_size(&self) -> usize {
58 self.input_size
59 }
60
61 pub fn hidden_size(&self) -> usize {
63 self.hidden_size
64 }
65
66 pub fn forward_step(&self, input: &Variable, hidden: &Variable) -> Variable {
68 let input_features = input.data().shape().last().copied().unwrap_or(0);
69 assert_eq!(
70 input_features, self.input_size,
71 "RNNCell: expected input size {}, got {}",
72 self.input_size, input_features
73 );
74 let weight_ih = self.weight_ih.variable();
76 let weight_ih_t = weight_ih.transpose(0, 1);
77 let ih = input.matmul(&weight_ih_t);
78 let bias_ih = self.bias_ih.variable();
79 let ih = ih.add_var(&bias_ih);
80
81 let weight_hh = self.weight_hh.variable();
83 let weight_hh_t = weight_hh.transpose(0, 1);
84 let hh = hidden.matmul(&weight_hh_t);
85 let bias_hh = self.bias_hh.variable();
86 let hh = hh.add_var(&bias_hh);
87
88 ih.add_var(&hh).tanh()
90 }
91}
92
93impl Module for RNNCell {
94 fn forward(&self, input: &Variable) -> Variable {
95 let batch_size = input.shape()[0];
97 let hidden = Variable::new(
98 zeros(&[batch_size, self.hidden_size]),
99 input.requires_grad(),
100 );
101 self.forward_step(input, &hidden)
102 }
103
104 fn parameters(&self) -> Vec<Parameter> {
105 vec![
106 self.weight_ih.clone(),
107 self.weight_hh.clone(),
108 self.bias_ih.clone(),
109 self.bias_hh.clone(),
110 ]
111 }
112
113 fn named_parameters(&self) -> HashMap<String, Parameter> {
114 let mut params = HashMap::new();
115 params.insert("weight_ih".to_string(), self.weight_ih.clone());
116 params.insert("weight_hh".to_string(), self.weight_hh.clone());
117 params.insert("bias_ih".to_string(), self.bias_ih.clone());
118 params.insert("bias_hh".to_string(), self.bias_hh.clone());
119 params
120 }
121
122 fn name(&self) -> &'static str {
123 "RNNCell"
124 }
125}
126
/// Multi-layer Elman RNN that unrolls [`RNNCell`]s over a sequence.
pub struct RNN {
    // Stacked cells: cells[0] consumes the raw input, the rest consume the
    // hidden state of the layer below.
    cells: Vec<RNNCell>,
    // Feature count expected at each time step.
    input_size: usize,
    // Hidden-state width of every layer.
    hidden_size: usize,
    // Number of stacked cells.
    num_layers: usize,
    // When true, input is [batch, seq, feature]; otherwise [seq, batch, feature].
    batch_first: bool,
}
146
147impl RNN {
148 pub fn new(input_size: usize, hidden_size: usize, num_layers: usize) -> Self {
150 Self::with_options(input_size, hidden_size, num_layers, true)
151 }
152
153 pub fn with_options(
155 input_size: usize,
156 hidden_size: usize,
157 num_layers: usize,
158 batch_first: bool,
159 ) -> Self {
160 let mut cells = Vec::with_capacity(num_layers);
161
162 cells.push(RNNCell::new(input_size, hidden_size));
164
165 for _ in 1..num_layers {
167 cells.push(RNNCell::new(hidden_size, hidden_size));
168 }
169
170 Self {
171 cells,
172 input_size,
173 hidden_size,
174 num_layers,
175 batch_first,
176 }
177 }
178}
179
180impl Module for RNN {
181 fn forward(&self, input: &Variable) -> Variable {
182 let shape = input.shape();
183 let (batch_size, seq_len, _) = if self.batch_first {
184 (shape[0], shape[1], shape[2])
185 } else {
186 (shape[1], shape[0], shape[2])
187 };
188
189 let mut hiddens: Vec<Variable> = (0..self.num_layers)
191 .map(|_| {
192 Variable::new(
193 zeros(&[batch_size, self.hidden_size]),
194 input.requires_grad(),
195 )
196 })
197 .collect();
198
199 let input_data = input.data();
201 let mut outputs = Vec::with_capacity(seq_len);
202
203 for t in 0..seq_len {
204 let t_input = if self.batch_first {
206 let mut slice_data = vec![0.0f32; batch_size * self.input_size];
208 let input_vec = input_data.to_vec();
209 for b in 0..batch_size {
210 for f in 0..self.input_size {
211 let src_idx = b * seq_len * self.input_size + t * self.input_size + f;
212 let dst_idx = b * self.input_size + f;
213 slice_data[dst_idx] = input_vec[src_idx];
214 }
215 }
216 Variable::new(
217 Tensor::from_vec(slice_data, &[batch_size, self.input_size]).unwrap(),
218 input.requires_grad(),
219 )
220 } else {
221 let mut slice_data = vec![0.0f32; batch_size * self.input_size];
223 let input_vec = input_data.to_vec();
224 for b in 0..batch_size {
225 for f in 0..self.input_size {
226 let src_idx = t * batch_size * self.input_size + b * self.input_size + f;
227 let dst_idx = b * self.input_size + f;
228 slice_data[dst_idx] = input_vec[src_idx];
229 }
230 }
231 Variable::new(
232 Tensor::from_vec(slice_data, &[batch_size, self.input_size]).unwrap(),
233 input.requires_grad(),
234 )
235 };
236
237 let mut layer_input = t_input;
239 for (l, cell) in self.cells.iter().enumerate() {
240 hiddens[l] = cell.forward_step(&layer_input, &hiddens[l]);
241 layer_input = hiddens[l].clone();
242 }
243
244 outputs.push(hiddens[self.num_layers - 1].clone());
245 }
246
247 let time_dim = if self.batch_first { 1 } else { 0 };
249 let unsqueezed: Vec<Variable> = outputs.iter()
250 .map(|o| o.unsqueeze(time_dim))
251 .collect();
252 let refs: Vec<&Variable> = unsqueezed.iter().collect();
253 Variable::cat(&refs, time_dim)
254 }
255
256 fn parameters(&self) -> Vec<Parameter> {
257 self.cells.iter().flat_map(|c| c.parameters()).collect()
258 }
259
260 fn name(&self) -> &'static str {
261 "RNN"
262 }
263}
264
/// A single LSTM cell operating on a `(hidden, cell)` state pair.
///
/// Weight matrices hold all four gates stacked along the output dimension
/// in the order `[input, forget, cell-candidate, output]` (the order in
/// which `forward_step` slices them).
pub struct LSTMCell {
    /// Input-to-hidden weights for all four gates (learnable).
    pub weight_ih: Parameter,
    /// Hidden-to-hidden weights for all four gates (learnable).
    pub weight_hh: Parameter,
    /// Input-to-hidden bias for all four gates (zero-initialized).
    pub bias_ih: Parameter,
    /// Hidden-to-hidden bias for all four gates (zero-initialized).
    pub bias_hh: Parameter,
    // Expected size of the input's last dimension; checked in `forward_step`.
    input_size: usize,
    // Width of the hidden and cell states.
    hidden_size: usize,
}
284
285impl LSTMCell {
286 pub fn new(input_size: usize, hidden_size: usize) -> Self {
288 Self {
290 weight_ih: Parameter::named(
291 "weight_ih",
292 xavier_uniform(input_size, 4 * hidden_size),
293 true,
294 ),
295 weight_hh: Parameter::named(
296 "weight_hh",
297 xavier_uniform(hidden_size, 4 * hidden_size),
298 true,
299 ),
300 bias_ih: Parameter::named("bias_ih", zeros(&[4 * hidden_size]), true),
301 bias_hh: Parameter::named("bias_hh", zeros(&[4 * hidden_size]), true),
302 input_size,
303 hidden_size,
304 }
305 }
306
307 pub fn input_size(&self) -> usize {
309 self.input_size
310 }
311
312 pub fn hidden_size(&self) -> usize {
314 self.hidden_size
315 }
316
317 pub fn forward_step(
319 &self,
320 input: &Variable,
321 hx: &(Variable, Variable),
322 ) -> (Variable, Variable) {
323 let input_features = input.data().shape().last().copied().unwrap_or(0);
324 assert_eq!(
325 input_features, self.input_size,
326 "LSTMCell: expected input size {}, got {}",
327 self.input_size, input_features
328 );
329
330 let (h, c) = hx;
331
332 let weight_ih = self.weight_ih.variable();
334 let weight_ih_t = weight_ih.transpose(0, 1);
335 let ih = input.matmul(&weight_ih_t);
336 let bias_ih = self.bias_ih.variable();
337 let ih = ih.add_var(&bias_ih);
338
339 let weight_hh = self.weight_hh.variable();
340 let weight_hh_t = weight_hh.transpose(0, 1);
341 let hh = h.matmul(&weight_hh_t);
342 let bias_hh = self.bias_hh.variable();
343 let hh = hh.add_var(&bias_hh);
344
345 let gates = ih.add_var(&hh);
346 let gates_vec = gates.data().to_vec();
347 let batch_size = input.shape()[0];
348
349 let mut i_data = vec![0.0f32; batch_size * self.hidden_size];
351 let mut f_data = vec![0.0f32; batch_size * self.hidden_size];
352 let mut g_data = vec![0.0f32; batch_size * self.hidden_size];
353 let mut o_data = vec![0.0f32; batch_size * self.hidden_size];
354
355 for b in 0..batch_size {
356 for j in 0..self.hidden_size {
357 let base = b * 4 * self.hidden_size;
358 i_data[b * self.hidden_size + j] = gates_vec[base + j];
359 f_data[b * self.hidden_size + j] = gates_vec[base + self.hidden_size + j];
360 g_data[b * self.hidden_size + j] = gates_vec[base + 2 * self.hidden_size + j];
361 o_data[b * self.hidden_size + j] = gates_vec[base + 3 * self.hidden_size + j];
362 }
363 }
364
365 let i = Variable::new(
366 Tensor::from_vec(i_data, &[batch_size, self.hidden_size]).unwrap(),
367 input.requires_grad(),
368 )
369 .sigmoid();
370 let f = Variable::new(
371 Tensor::from_vec(f_data, &[batch_size, self.hidden_size]).unwrap(),
372 input.requires_grad(),
373 )
374 .sigmoid();
375 let g = Variable::new(
376 Tensor::from_vec(g_data, &[batch_size, self.hidden_size]).unwrap(),
377 input.requires_grad(),
378 )
379 .tanh();
380 let o = Variable::new(
381 Tensor::from_vec(o_data, &[batch_size, self.hidden_size]).unwrap(),
382 input.requires_grad(),
383 )
384 .sigmoid();
385
386 let c_new = f.mul_var(c).add_var(&i.mul_var(&g));
388
389 let h_new = o.mul_var(&c_new.tanh());
391
392 (h_new, c_new)
393 }
394}
395
396impl Module for LSTMCell {
397 fn forward(&self, input: &Variable) -> Variable {
398 let batch_size = input.shape()[0];
399 let h = Variable::new(
400 zeros(&[batch_size, self.hidden_size]),
401 input.requires_grad(),
402 );
403 let c = Variable::new(
404 zeros(&[batch_size, self.hidden_size]),
405 input.requires_grad(),
406 );
407 let (h_new, _) = self.forward_step(input, &(h, c));
408 h_new
409 }
410
411 fn parameters(&self) -> Vec<Parameter> {
412 vec![
413 self.weight_ih.clone(),
414 self.weight_hh.clone(),
415 self.bias_ih.clone(),
416 self.bias_hh.clone(),
417 ]
418 }
419
420 fn named_parameters(&self) -> HashMap<String, Parameter> {
421 let mut params = HashMap::new();
422 params.insert("weight_ih".to_string(), self.weight_ih.clone());
423 params.insert("weight_hh".to_string(), self.weight_hh.clone());
424 params.insert("bias_ih".to_string(), self.bias_ih.clone());
425 params.insert("bias_hh".to_string(), self.bias_hh.clone());
426 params
427 }
428
429 fn name(&self) -> &'static str {
430 "LSTMCell"
431 }
432}
433
/// Multi-layer LSTM that unrolls [`LSTMCell`]s over a sequence.
pub struct LSTM {
    // Stacked cells: cells[0] consumes the raw input, the rest consume the
    // hidden state of the layer below.
    cells: Vec<LSTMCell>,
    // Feature count expected at each time step.
    input_size: usize,
    // Hidden/cell-state width of every layer.
    hidden_size: usize,
    // Number of stacked cells.
    num_layers: usize,
    // When true, input is [batch, seq, feature]; otherwise [seq, batch, feature].
    batch_first: bool,
}
451
452impl LSTM {
453 pub fn new(input_size: usize, hidden_size: usize, num_layers: usize) -> Self {
455 Self::with_options(input_size, hidden_size, num_layers, true)
456 }
457
458 pub fn with_options(
460 input_size: usize,
461 hidden_size: usize,
462 num_layers: usize,
463 batch_first: bool,
464 ) -> Self {
465 let mut cells = Vec::with_capacity(num_layers);
466 cells.push(LSTMCell::new(input_size, hidden_size));
467 for _ in 1..num_layers {
468 cells.push(LSTMCell::new(hidden_size, hidden_size));
469 }
470
471 Self {
472 cells,
473 input_size,
474 hidden_size,
475 num_layers,
476 batch_first,
477 }
478 }
479
480 pub fn input_size(&self) -> usize {
482 self.input_size
483 }
484
485 pub fn hidden_size(&self) -> usize {
487 self.hidden_size
488 }
489
490 pub fn num_layers(&self) -> usize {
492 self.num_layers
493 }
494}
495
496impl Module for LSTM {
497 fn forward(&self, input: &Variable) -> Variable {
498 let shape = input.shape();
501 let (batch_size, seq_len, input_features) = if self.batch_first {
502 (shape[0], shape[1], shape[2])
503 } else {
504 (shape[1], shape[0], shape[2])
505 };
506
507 let mut states: Vec<(Variable, Variable)> = (0..self.num_layers)
508 .map(|_| {
509 (
510 Variable::new(
511 zeros(&[batch_size, self.hidden_size]),
512 input.requires_grad(),
513 ),
514 Variable::new(
515 zeros(&[batch_size, self.hidden_size]),
516 input.requires_grad(),
517 ),
518 )
519 })
520 .collect();
521
522 let input_data = input.data();
523 let input_vec = input_data.to_vec();
524 let mut outputs = Vec::with_capacity(seq_len);
525
526 for t in 0..seq_len {
527 let mut slice_data = vec![0.0f32; batch_size * input_features];
528 for b in 0..batch_size {
529 for f in 0..input_features {
530 let src_idx = if self.batch_first {
531 b * seq_len * input_features + t * input_features + f
532 } else {
533 t * batch_size * input_features + b * input_features + f
534 };
535 slice_data[b * input_features + f] = input_vec[src_idx];
536 }
537 }
538
539 let mut layer_input = Variable::new(
541 Tensor::from_vec(slice_data.clone(), &[batch_size, input_features]).unwrap(),
542 input.requires_grad(),
543 );
544
545 for (l, cell) in self.cells.iter().enumerate() {
546 if l > 0 {
548 layer_input = states[l - 1].0.clone();
549 }
550 states[l] = cell.forward_step(&layer_input, &states[l]);
551 }
552
553 outputs.push(states[self.num_layers - 1].0.clone());
554 }
555
556 let time_dim = if self.batch_first { 1 } else { 0 };
558 let unsqueezed: Vec<Variable> = outputs.iter()
559 .map(|o| o.unsqueeze(time_dim))
560 .collect();
561 let refs: Vec<&Variable> = unsqueezed.iter().collect();
562 Variable::cat(&refs, time_dim)
563 }
564
565 fn parameters(&self) -> Vec<Parameter> {
566 self.cells.iter().flat_map(|c| c.parameters()).collect()
567 }
568
569 fn named_parameters(&self) -> HashMap<String, Parameter> {
570 let mut params = HashMap::new();
571 if self.cells.len() == 1 {
572 for (n, p) in self.cells[0].named_parameters() {
574 params.insert(n, p);
575 }
576 } else {
577 for (i, cell) in self.cells.iter().enumerate() {
578 for (n, p) in cell.named_parameters() {
579 params.insert(format!("cells.{i}.{n}"), p);
580 }
581 }
582 }
583 params
584 }
585
586 fn name(&self) -> &'static str {
587 "LSTM"
588 }
589}
590
/// A single GRU cell.
///
/// Weight matrices hold all three gates stacked along the output dimension
/// in the order `[reset, update, new-candidate]` (the order in which
/// `forward_step` slices them with `narrow`).
pub struct GRUCell {
    /// Input-to-hidden weights for all three gates (learnable).
    pub weight_ih: Parameter,
    /// Hidden-to-hidden weights for all three gates (learnable).
    pub weight_hh: Parameter,
    /// Input-to-hidden bias for all three gates (zero-initialized).
    pub bias_ih: Parameter,
    /// Hidden-to-hidden bias for all three gates (zero-initialized).
    pub bias_hh: Parameter,
    // Expected size of the input's last dimension.
    input_size: usize,
    // Width of the hidden state.
    hidden_size: usize,
}
616
617impl GRUCell {
618 pub fn new(input_size: usize, hidden_size: usize) -> Self {
620 Self {
621 weight_ih: Parameter::named(
622 "weight_ih",
623 xavier_uniform(input_size, 3 * hidden_size),
624 true,
625 ),
626 weight_hh: Parameter::named(
627 "weight_hh",
628 xavier_uniform(hidden_size, 3 * hidden_size),
629 true,
630 ),
631 bias_ih: Parameter::named("bias_ih", zeros(&[3 * hidden_size]), true),
632 bias_hh: Parameter::named("bias_hh", zeros(&[3 * hidden_size]), true),
633 input_size,
634 hidden_size,
635 }
636 }
637
638 pub fn input_size(&self) -> usize {
640 self.input_size
641 }
642
643 pub fn hidden_size(&self) -> usize {
645 self.hidden_size
646 }
647}
648
649impl GRUCell {
650 pub fn forward_step(&self, input: &Variable, hidden: &Variable) -> Variable {
660 let batch_size = input.shape()[0];
661 let hidden_size = self.hidden_size;
662
663 let weight_ih = self.weight_ih.variable();
665 let weight_hh = self.weight_hh.variable();
666 let bias_ih = self.bias_ih.variable();
667 let bias_hh = self.bias_hh.variable();
668
669 let weight_ih_t = weight_ih.transpose(0, 1);
672 let ih = input.matmul(&weight_ih_t).add_var(&bias_ih);
673
674 let weight_hh_t = weight_hh.transpose(0, 1);
677 let hh = hidden.matmul(&weight_hh_t).add_var(&bias_hh);
678
679 let ih_r = ih.narrow(1, 0, hidden_size);
682 let ih_z = ih.narrow(1, hidden_size, hidden_size);
683 let ih_n = ih.narrow(1, 2 * hidden_size, hidden_size);
684
685 let hh_r = hh.narrow(1, 0, hidden_size);
686 let hh_z = hh.narrow(1, hidden_size, hidden_size);
687 let hh_n = hh.narrow(1, 2 * hidden_size, hidden_size);
688
689 let r = ih_r.add_var(&hh_r).sigmoid();
692
693 let z = ih_z.add_var(&hh_z).sigmoid();
695
696 let n = ih_n.add_var(&r.mul_var(&hh_n)).tanh();
698
699 let shape = [batch_size, hidden_size];
702 let ones = Variable::new(
703 Tensor::from_vec(vec![1.0f32; batch_size * hidden_size], &shape).unwrap(),
704 false,
705 );
706 let one_minus_z = ones.sub_var(&z);
707
708 one_minus_z.mul_var(&n).add_var(&z.mul_var(hidden))
710 }
711}
712
713impl Module for GRUCell {
714 fn forward(&self, input: &Variable) -> Variable {
715 let batch_size = input.shape()[0];
716
717 let hidden = Variable::new(
719 zeros(&[batch_size, self.hidden_size]),
720 input.requires_grad(),
721 );
722
723 self.forward_step(input, &hidden)
724 }
725
726 fn parameters(&self) -> Vec<Parameter> {
727 vec![
728 self.weight_ih.clone(),
729 self.weight_hh.clone(),
730 self.bias_ih.clone(),
731 self.bias_hh.clone(),
732 ]
733 }
734
735 fn named_parameters(&self) -> HashMap<String, Parameter> {
736 let mut params = HashMap::new();
737 params.insert("weight_ih".to_string(), self.weight_ih.clone());
738 params.insert("weight_hh".to_string(), self.weight_hh.clone());
739 params.insert("bias_ih".to_string(), self.bias_ih.clone());
740 params.insert("bias_hh".to_string(), self.bias_hh.clone());
741 params
742 }
743
744 fn name(&self) -> &'static str {
745 "GRUCell"
746 }
747}
748
/// Multi-layer GRU that unrolls [`GRUCell`]s over a sequence.
pub struct GRU {
    // Stacked cells: cells[0] consumes the raw input, the rest consume the
    // hidden state of the layer below.
    cells: Vec<GRUCell>,
    // Hidden-state width of every layer.
    hidden_size: usize,
    // Number of stacked cells.
    num_layers: usize,
    // When true, input is [batch, seq, feature]; otherwise [seq, batch, feature].
    // `new` always sets this to true.
    batch_first: bool,
}
760
761impl GRU {
762 pub fn new(input_size: usize, hidden_size: usize, num_layers: usize) -> Self {
764 let mut cells = Vec::with_capacity(num_layers);
765 cells.push(GRUCell::new(input_size, hidden_size));
766 for _ in 1..num_layers {
767 cells.push(GRUCell::new(hidden_size, hidden_size));
768 }
769 Self {
770 cells,
771 hidden_size,
772 num_layers,
773 batch_first: true,
774 }
775 }
776
777 pub fn hidden_size(&self) -> usize {
779 self.hidden_size
780 }
781
782 pub fn num_layers(&self) -> usize {
784 self.num_layers
785 }
786}
787
788impl Module for GRU {
789 fn forward(&self, input: &Variable) -> Variable {
790 let shape = input.shape();
791 let (batch_size, seq_len, _input_size) = if self.batch_first {
792 (shape[0], shape[1], shape[2])
793 } else {
794 (shape[1], shape[0], shape[2])
795 };
796
797 let mut hidden_states: Vec<Variable> = (0..self.num_layers)
799 .map(|_| {
800 Variable::new(
801 zeros(&[batch_size, self.hidden_size]),
802 input.requires_grad(),
803 )
804 })
805 .collect();
806
807 let mut output_vars: Vec<Variable> = Vec::with_capacity(seq_len);
809
810 for t in 0..seq_len {
812 let narrowed = input.narrow(1, t, 1);
817 let step_input = narrowed.reshape(&[batch_size, narrowed.data().numel() / batch_size]);
818
819 let mut layer_input = step_input;
821
822 for (layer_idx, cell) in self.cells.iter().enumerate() {
823 let new_hidden = cell.forward_step(&layer_input, &hidden_states[layer_idx]);
824
825 hidden_states[layer_idx] = new_hidden.clone();
827
828 layer_input = new_hidden;
830 }
831
832 output_vars.push(layer_input);
834 }
835
836 self.stack_outputs(&output_vars, batch_size, seq_len)
840 }
841
842 fn parameters(&self) -> Vec<Parameter> {
843 self.cells.iter().flat_map(|c| c.parameters()).collect()
844 }
845
846 fn named_parameters(&self) -> HashMap<String, Parameter> {
847 let mut params = HashMap::new();
848 if self.cells.len() == 1 {
849 for (n, p) in self.cells[0].named_parameters() {
850 params.insert(n, p);
851 }
852 } else {
853 for (i, cell) in self.cells.iter().enumerate() {
854 for (n, p) in cell.named_parameters() {
855 params.insert(format!("cells.{i}.{n}"), p);
856 }
857 }
858 }
859 params
860 }
861
862 fn name(&self) -> &'static str {
863 "GRU"
864 }
865}
866
867impl GRU {
868 pub fn forward_mean(&self, input: &Variable) -> Variable {
871 let shape = input.shape();
872 let (batch_size, seq_len, _input_size) = if self.batch_first {
873 (shape[0], shape[1], shape[2])
874 } else {
875 (shape[1], shape[0], shape[2])
876 };
877
878 let mut hidden_states: Vec<Variable> = (0..self.num_layers)
880 .map(|_| {
881 Variable::new(
882 zeros(&[batch_size, self.hidden_size]),
883 input.requires_grad(),
884 )
885 })
886 .collect();
887
888 let mut output_sum: Option<Variable> = None;
890
891 for t in 0..seq_len {
893 let narrowed = input.narrow(1, t, 1);
896 let step_input = narrowed.reshape(&[batch_size, narrowed.data().numel() / batch_size]);
897
898 let mut layer_input = step_input;
900
901 for (layer_idx, cell) in self.cells.iter().enumerate() {
902 let new_hidden = cell.forward_step(&layer_input, &hidden_states[layer_idx]);
903 hidden_states[layer_idx] = new_hidden.clone();
904 layer_input = new_hidden;
905 }
906
907 output_sum = Some(match output_sum {
909 None => layer_input,
910 Some(acc) => acc.add_var(&layer_input),
911 });
912 }
913
914 match output_sum {
916 Some(sum) => sum.mul_scalar(1.0 / seq_len as f32),
917 None => Variable::new(zeros(&[batch_size, self.hidden_size]), false),
918 }
919 }
920
921 pub fn forward_last(&self, input: &Variable) -> Variable {
924 let shape = input.shape();
925 let (batch_size, seq_len, _input_size) = if self.batch_first {
926 (shape[0], shape[1], shape[2])
927 } else {
928 (shape[1], shape[0], shape[2])
929 };
930
931 let mut hidden_states: Vec<Variable> = (0..self.num_layers)
933 .map(|_| {
934 Variable::new(
935 zeros(&[batch_size, self.hidden_size]),
936 input.requires_grad(),
937 )
938 })
939 .collect();
940
941 for t in 0..seq_len {
943 let narrowed = input.narrow(1, t, 1);
945 let step_input = narrowed.reshape(&[batch_size, narrowed.data().numel() / batch_size]);
946
947 let mut layer_input = step_input;
948
949 for (layer_idx, cell) in self.cells.iter().enumerate() {
950 let new_hidden = cell.forward_step(&layer_input, &hidden_states[layer_idx]);
951 hidden_states[layer_idx] = new_hidden.clone();
952 layer_input = new_hidden;
953 }
954 }
955
956 hidden_states
958 .pop()
959 .unwrap_or_else(|| Variable::new(zeros(&[batch_size, self.hidden_size]), false))
960 }
961
962 fn stack_outputs(&self, outputs: &[Variable], batch_size: usize, _seq_len: usize) -> Variable {
966 if outputs.is_empty() {
967 return Variable::new(zeros(&[batch_size, 0, self.hidden_size]), false);
968 }
969
970 let unsqueezed: Vec<Variable> = outputs.iter()
972 .map(|o| o.unsqueeze(1))
973 .collect();
974 let refs: Vec<&Variable> = unsqueezed.iter().collect();
975 Variable::cat(&refs, 1)
976 }
977}
978
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_rnn_cell() {
        let cell = RNNCell::new(10, 20);
        let input = Variable::new(Tensor::from_vec(vec![1.0; 20], &[2, 10]).unwrap(), false);
        let hidden = Variable::new(Tensor::from_vec(vec![0.0; 40], &[2, 20]).unwrap(), false);
        let output = cell.forward_step(&input, &hidden);
        assert_eq!(output.shape(), vec![2, 20]);
    }

    #[test]
    fn test_rnn() {
        let rnn = RNN::new(10, 20, 2);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 100], &[2, 5, 10]).unwrap(),
            false,
        );
        let output = rnn.forward(&input);
        assert_eq!(output.shape(), vec![2, 5, 20]);
    }

    #[test]
    fn test_lstm() {
        let lstm = LSTM::new(10, 20, 1);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 100], &[2, 5, 10]).unwrap(),
            false,
        );
        let output = lstm.forward(&input);
        assert_eq!(output.shape(), vec![2, 5, 20]);
    }

    /// Backprop through a small GRU must deliver a non-zero gradient to at
    /// least one parameter — guards against graph-detaching regressions.
    /// (Debug `println!` scaffolding removed from the original test.)
    #[test]
    fn test_gru_gradients_reach_parameters() {
        let gru = GRU::new(4, 8, 1);
        let input = Variable::new(
            Tensor::from_vec(vec![0.5f32; 2 * 3 * 4], &[2, 3, 4]).unwrap(),
            true,
        );

        let loss = gru.forward(&input).sum();
        loss.backward();

        let has_grad = gru.parameters().iter().any(|p| {
            p.grad()
                .map(|g| g.to_vec().iter().map(|x| x.abs()).sum::<f32>() > 0.0)
                .unwrap_or(false)
        });
        assert!(
            has_grad,
            "At least one GRU parameter should have non-zero gradients"
        );
    }
}