use std::collections::HashMap;

use axonml_autograd::Variable;
use axonml_tensor::Tensor;

use crate::init::{xavier_uniform, zeros};
use crate::module::Module;
use crate::parameter::Parameter;

/// A single Elman RNN cell: `h' = tanh(W_ih x + b_ih + W_hh h + b_hh)`.
pub struct RNNCell {
    /// Input-to-hidden weight matrix.
    pub weight_ih: Parameter,
    /// Hidden-to-hidden weight matrix.
    pub weight_hh: Parameter,
    /// Input-to-hidden bias.
    pub bias_ih: Parameter,
    /// Hidden-to-hidden bias.
    pub bias_hh: Parameter,
    input_size: usize,
    hidden_size: usize,
}

impl RNNCell {
    /// Creates a cell with Xavier-initialized weights and zero biases.
    pub fn new(input_size: usize, hidden_size: usize) -> Self {
        Self {
            weight_ih: Parameter::named("weight_ih", xavier_uniform(input_size, hidden_size), true),
            weight_hh: Parameter::named(
                "weight_hh",
                xavier_uniform(hidden_size, hidden_size),
                true,
            ),
            bias_ih: Parameter::named("bias_ih", zeros(&[hidden_size]), true),
            bias_hh: Parameter::named("bias_hh", zeros(&[hidden_size]), true),
            input_size,
            hidden_size,
        }
    }

    /// Returns the expected number of input features.
    pub fn input_size(&self) -> usize {
        self.input_size
    }

    /// Returns the hidden-state size.
    pub fn hidden_size(&self) -> usize {
        self.hidden_size
    }

    /// Runs one timestep: `input` is `[batch, input_size]`, `hidden` is
    /// `[batch, hidden_size]`; returns the new hidden state.
    pub fn forward_step(&self, input: &Variable, hidden: &Variable) -> Variable {
        let input_features = input.data().shape().last().copied().unwrap_or(0);
        assert_eq!(
            input_features, self.input_size,
            "RNNCell: expected input size {}, got {}",
            self.input_size, input_features
        );

        // Weights are stored `[out, in]`, so transpose before the matmul.
        let weight_ih = self.weight_ih.variable();
        let weight_ih_t = Variable::new(weight_ih.data().t().unwrap(), weight_ih.requires_grad());
        let ih = input.matmul(&weight_ih_t).add_var(&self.bias_ih.variable());

        let weight_hh = self.weight_hh.variable();
        let weight_hh_t = Variable::new(weight_hh.data().t().unwrap(), weight_hh.requires_grad());
        let hh = hidden.matmul(&weight_hh_t).add_var(&self.bias_hh.variable());

        ih.add_var(&hh).tanh()
    }
}

impl Module for RNNCell {
    fn forward(&self, input: &Variable) -> Variable {
        // Without an explicit hidden state, start the recurrence from zeros.
        let batch_size = input.shape()[0];
        let hidden = Variable::new(
            zeros(&[batch_size, self.hidden_size]),
            input.requires_grad(),
        );
        self.forward_step(input, &hidden)
    }

    fn parameters(&self) -> Vec<Parameter> {
        vec![
            self.weight_ih.clone(),
            self.weight_hh.clone(),
            self.bias_ih.clone(),
            self.bias_hh.clone(),
        ]
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut params = HashMap::new();
        params.insert("weight_ih".to_string(), self.weight_ih.clone());
        params.insert("weight_hh".to_string(), self.weight_hh.clone());
        params.insert("bias_ih".to_string(), self.bias_ih.clone());
        params.insert("bias_hh".to_string(), self.bias_hh.clone());
        params
    }

    fn name(&self) -> &'static str {
        "RNNCell"
    }
}

/// A multi-layer Elman RNN applied over a full sequence.
pub struct RNN {
    cells: Vec<RNNCell>,
    input_size: usize,
    hidden_size: usize,
    num_layers: usize,
    batch_first: bool,
}

impl RNN {
    /// Creates a batch-first RNN (`[batch, seq, features]` input layout).
    pub fn new(input_size: usize, hidden_size: usize, num_layers: usize) -> Self {
        Self::with_options(input_size, hidden_size, num_layers, true)
    }

    /// Creates an RNN; `batch_first = false` selects the `[seq, batch, features]` layout.
    pub fn with_options(
        input_size: usize,
        hidden_size: usize,
        num_layers: usize,
        batch_first: bool,
    ) -> Self {
        let mut cells = Vec::with_capacity(num_layers);

        // The first layer maps input_size -> hidden_size; deeper layers map
        // hidden_size -> hidden_size.
        cells.push(RNNCell::new(input_size, hidden_size));
        for _ in 1..num_layers {
            cells.push(RNNCell::new(hidden_size, hidden_size));
        }

        Self {
            cells,
            input_size,
            hidden_size,
            num_layers,
            batch_first,
        }
    }
}

impl Module for RNN {
    fn forward(&self, input: &Variable) -> Variable {
        let shape = input.shape();
        let (batch_size, seq_len) = if self.batch_first {
            (shape[0], shape[1])
        } else {
            (shape[1], shape[0])
        };
        let time_dim = if self.batch_first { 1 } else { 0 };

        // One hidden state per layer, initialized to zeros.
        let mut hiddens: Vec<Variable> = (0..self.num_layers)
            .map(|_| {
                Variable::new(
                    zeros(&[batch_size, self.hidden_size]),
                    input.requires_grad(),
                )
            })
            .collect();

        let mut outputs = Vec::with_capacity(seq_len);

        for t in 0..seq_len {
            // Slice out timestep `t` as `[batch, input_size]`. Using `narrow`
            // keeps the slice attached to the autograd graph (an element-wise
            // copy into a fresh Variable would detach it).
            let t_input = input
                .narrow(time_dim, t, 1)
                .reshape(&[batch_size, self.input_size]);

            // Feed the timestep up through the layer stack.
            let mut layer_input = t_input;
            for (l, cell) in self.cells.iter().enumerate() {
                hiddens[l] = cell.forward_step(&layer_input, &hiddens[l]);
                layer_input = hiddens[l].clone();
            }

            outputs.push(hiddens[self.num_layers - 1].clone());
        }

        // Pack the per-step top-layer outputs into one tensor. Note that this
        // copies raw data, so the packed output is detached from the autograd
        // graph.
        let mut output_data = vec![0.0f32; batch_size * seq_len * self.hidden_size];
        for (t, out) in outputs.iter().enumerate() {
            let out_vec = out.data().to_vec();
            for b in 0..batch_size {
                for h in 0..self.hidden_size {
                    let src_idx = b * self.hidden_size + h;
                    let dst_idx = if self.batch_first {
                        b * seq_len * self.hidden_size + t * self.hidden_size + h
                    } else {
                        t * batch_size * self.hidden_size + b * self.hidden_size + h
                    };
                    output_data[dst_idx] = out_vec[src_idx];
                }
            }
        }

        let output_shape = if self.batch_first {
            vec![batch_size, seq_len, self.hidden_size]
        } else {
            vec![seq_len, batch_size, self.hidden_size]
        };

        Variable::new(
            Tensor::from_vec(output_data, &output_shape).unwrap(),
            input.requires_grad(),
        )
    }

    fn parameters(&self) -> Vec<Parameter> {
        self.cells.iter().flat_map(|c| c.parameters()).collect()
    }

    fn name(&self) -> &'static str {
        "RNN"
    }
}

/// A single LSTM cell with input, forget, cell, and output gates.
pub struct LSTMCell {
    /// Input-to-hidden weights for all four gates, stacked along the output dim.
    pub weight_ih: Parameter,
    /// Hidden-to-hidden weights for all four gates, stacked along the output dim.
    pub weight_hh: Parameter,
    /// Input-to-hidden bias for all four gates.
    pub bias_ih: Parameter,
    /// Hidden-to-hidden bias for all four gates.
    pub bias_hh: Parameter,
    input_size: usize,
    hidden_size: usize,
}

impl LSTMCell {
    /// Creates a cell with Xavier-initialized weights and zero biases. Gate
    /// parameters are stored concatenated in `i, f, g, o` order.
    pub fn new(input_size: usize, hidden_size: usize) -> Self {
        Self {
            weight_ih: Parameter::named(
                "weight_ih",
                xavier_uniform(input_size, 4 * hidden_size),
                true,
            ),
            weight_hh: Parameter::named(
                "weight_hh",
                xavier_uniform(hidden_size, 4 * hidden_size),
                true,
            ),
            bias_ih: Parameter::named("bias_ih", zeros(&[4 * hidden_size]), true),
            bias_hh: Parameter::named("bias_hh", zeros(&[4 * hidden_size]), true),
            input_size,
            hidden_size,
        }
    }

    /// Returns the expected number of input features.
    pub fn input_size(&self) -> usize {
        self.input_size
    }

    /// Returns the hidden-state size.
    pub fn hidden_size(&self) -> usize {
        self.hidden_size
    }

    /// Runs one timestep. `hx` is the previous `(hidden, cell)` state pair;
    /// the updated pair is returned.
    pub fn forward_step(
        &self,
        input: &Variable,
        hx: &(Variable, Variable),
    ) -> (Variable, Variable) {
        let input_features = input.data().shape().last().copied().unwrap_or(0);
        assert_eq!(
            input_features, self.input_size,
            "LSTMCell: expected input size {}, got {}",
            self.input_size, input_features
        );

        let (h, c) = hx;

        let weight_ih = self.weight_ih.variable();
        let weight_ih_t = Variable::new(weight_ih.data().t().unwrap(), weight_ih.requires_grad());
        let ih = input.matmul(&weight_ih_t).add_var(&self.bias_ih.variable());

        let weight_hh = self.weight_hh.variable();
        let weight_hh_t = Variable::new(weight_hh.data().t().unwrap(), weight_hh.requires_grad());
        let hh = h.matmul(&weight_hh_t).add_var(&self.bias_hh.variable());

        // All four gate pre-activations, shape `[batch, 4 * hidden_size]`.
        let gates = ih.add_var(&hh);

        // Split the gates with `narrow` so they stay on the autograd graph
        // (copying through raw data would detach them), then apply the usual
        // nonlinearities: sigmoid for i/f/o, tanh for the candidate g.
        let hs = self.hidden_size;
        let i = gates.narrow(1, 0, hs).sigmoid();
        let f = gates.narrow(1, hs, hs).sigmoid();
        let g = gates.narrow(1, 2 * hs, hs).tanh();
        let o = gates.narrow(1, 3 * hs, hs).sigmoid();

        // c' = f * c + i * g
        let c_new = f.mul_var(c).add_var(&i.mul_var(&g));

        // h' = o * tanh(c')
        let h_new = o.mul_var(&c_new.tanh());

        (h_new, c_new)
    }
}

impl Module for LSTMCell {
    fn forward(&self, input: &Variable) -> Variable {
        // Without an explicit state, start from zero hidden and cell states;
        // only the new hidden state is returned.
        let batch_size = input.shape()[0];
        let h = Variable::new(
            zeros(&[batch_size, self.hidden_size]),
            input.requires_grad(),
        );
        let c = Variable::new(
            zeros(&[batch_size, self.hidden_size]),
            input.requires_grad(),
        );
        let (h_new, _) = self.forward_step(input, &(h, c));
        h_new
    }

    fn parameters(&self) -> Vec<Parameter> {
        vec![
            self.weight_ih.clone(),
            self.weight_hh.clone(),
            self.bias_ih.clone(),
            self.bias_hh.clone(),
        ]
    }

    fn name(&self) -> &'static str {
        "LSTMCell"
    }
}

/// A multi-layer LSTM applied over a full sequence.
pub struct LSTM {
    cells: Vec<LSTMCell>,
    input_size: usize,
    hidden_size: usize,
    num_layers: usize,
    batch_first: bool,
}

impl LSTM {
    /// Creates a batch-first LSTM (`[batch, seq, features]` input layout).
    pub fn new(input_size: usize, hidden_size: usize, num_layers: usize) -> Self {
        Self::with_options(input_size, hidden_size, num_layers, true)
    }

    /// Creates an LSTM; `batch_first = false` selects the `[seq, batch, features]` layout.
    pub fn with_options(
        input_size: usize,
        hidden_size: usize,
        num_layers: usize,
        batch_first: bool,
    ) -> Self {
        let mut cells = Vec::with_capacity(num_layers);
        cells.push(LSTMCell::new(input_size, hidden_size));
        for _ in 1..num_layers {
            cells.push(LSTMCell::new(hidden_size, hidden_size));
        }

        Self {
            cells,
            input_size,
            hidden_size,
            num_layers,
            batch_first,
        }
    }

    /// Returns the expected number of input features.
    pub fn input_size(&self) -> usize {
        self.input_size
    }

    /// Returns the hidden-state size.
    pub fn hidden_size(&self) -> usize {
        self.hidden_size
    }

    /// Returns the number of stacked layers.
    pub fn num_layers(&self) -> usize {
        self.num_layers
    }
}

impl Module for LSTM {
    fn forward(&self, input: &Variable) -> Variable {
        let shape = input.shape();
        let (batch_size, seq_len, input_features) = if self.batch_first {
            (shape[0], shape[1], shape[2])
        } else {
            (shape[1], shape[0], shape[2])
        };
        let time_dim = if self.batch_first { 1 } else { 0 };

        // One (hidden, cell) state pair per layer, initialized to zeros.
        let mut states: Vec<(Variable, Variable)> = (0..self.num_layers)
            .map(|_| {
                (
                    Variable::new(
                        zeros(&[batch_size, self.hidden_size]),
                        input.requires_grad(),
                    ),
                    Variable::new(
                        zeros(&[batch_size, self.hidden_size]),
                        input.requires_grad(),
                    ),
                )
            })
            .collect();

        let mut outputs = Vec::with_capacity(seq_len);

        for t in 0..seq_len {
            // Slice out timestep `t` as `[batch, input_features]`; `narrow`
            // keeps the slice attached to the autograd graph.
            let mut layer_input = input
                .narrow(time_dim, t, 1)
                .reshape(&[batch_size, input_features]);

            for (l, cell) in self.cells.iter().enumerate() {
                states[l] = cell.forward_step(&layer_input, &states[l]);
                layer_input = states[l].0.clone();
            }

            outputs.push(states[self.num_layers - 1].0.clone());
        }

        // Pack the per-step top-layer hidden states into one tensor. This
        // copies raw data, so the packed output is detached from the autograd
        // graph.
        let mut output_data = vec![0.0f32; batch_size * seq_len * self.hidden_size];
        for (t, out) in outputs.iter().enumerate() {
            let out_vec = out.data().to_vec();
            for b in 0..batch_size {
                for h in 0..self.hidden_size {
                    let dst_idx = if self.batch_first {
                        b * seq_len * self.hidden_size + t * self.hidden_size + h
                    } else {
                        t * batch_size * self.hidden_size + b * self.hidden_size + h
                    };
                    output_data[dst_idx] = out_vec[b * self.hidden_size + h];
                }
            }
        }

        let output_shape = if self.batch_first {
            vec![batch_size, seq_len, self.hidden_size]
        } else {
            vec![seq_len, batch_size, self.hidden_size]
        };

        Variable::new(
            Tensor::from_vec(output_data, &output_shape).unwrap(),
            input.requires_grad(),
        )
    }

    fn parameters(&self) -> Vec<Parameter> {
        self.cells.iter().flat_map(|c| c.parameters()).collect()
    }

    fn name(&self) -> &'static str {
        "LSTM"
    }
}

/// A single GRU cell with reset (`r`), update (`z`), and new (`n`) gates.
pub struct GRUCell {
    /// Input-to-hidden weights for all three gates, stacked along the output dim.
    pub weight_ih: Parameter,
    /// Hidden-to-hidden weights for all three gates, stacked along the output dim.
    pub weight_hh: Parameter,
    /// Input-to-hidden bias for all three gates.
    pub bias_ih: Parameter,
    /// Hidden-to-hidden bias for all three gates.
    pub bias_hh: Parameter,
    input_size: usize,
    hidden_size: usize,
}

impl GRUCell {
    /// Creates a cell with Xavier-initialized weights and zero biases. Gate
    /// parameters are stored concatenated in `r, z, n` order.
    pub fn new(input_size: usize, hidden_size: usize) -> Self {
        Self {
            weight_ih: Parameter::named(
                "weight_ih",
                xavier_uniform(input_size, 3 * hidden_size),
                true,
            ),
            weight_hh: Parameter::named(
                "weight_hh",
                xavier_uniform(hidden_size, 3 * hidden_size),
                true,
            ),
            bias_ih: Parameter::named("bias_ih", zeros(&[3 * hidden_size]), true),
            bias_hh: Parameter::named("bias_hh", zeros(&[3 * hidden_size]), true),
            input_size,
            hidden_size,
        }
    }

    /// Returns the expected number of input features.
    pub fn input_size(&self) -> usize {
        self.input_size
    }

    /// Returns the hidden-state size.
    pub fn hidden_size(&self) -> usize {
        self.hidden_size
    }

    /// Runs one timestep:
    /// `r = sigmoid(ih_r + hh_r)`, `z = sigmoid(ih_z + hh_z)`,
    /// `n = tanh(ih_n + r * hh_n)`, `h' = (1 - z) * n + z * h`.
    pub fn forward_step(&self, input: &Variable, hidden: &Variable) -> Variable {
        let batch_size = input.shape()[0];
        let hidden_size = self.hidden_size;

        let weight_ih = self.weight_ih.variable();
        let weight_hh = self.weight_hh.variable();
        let bias_ih = self.bias_ih.variable();
        let bias_hh = self.bias_hh.variable();

        // Pre-activations for all three gates: `[batch, 3 * hidden_size]`.
        let weight_ih_t = Variable::new(weight_ih.data().t().unwrap(), weight_ih.requires_grad());
        let ih = input.matmul(&weight_ih_t).add_var(&bias_ih);

        let weight_hh_t = Variable::new(weight_hh.data().t().unwrap(), weight_hh.requires_grad());
        let hh = hidden.matmul(&weight_hh_t).add_var(&bias_hh);

        // Split out the r/z/n chunks of both projections.
        let ih_r = ih.narrow(1, 0, hidden_size);
        let ih_z = ih.narrow(1, hidden_size, hidden_size);
        let ih_n = ih.narrow(1, 2 * hidden_size, hidden_size);

        let hh_r = hh.narrow(1, 0, hidden_size);
        let hh_z = hh.narrow(1, hidden_size, hidden_size);
        let hh_n = hh.narrow(1, 2 * hidden_size, hidden_size);

        let r = ih_r.add_var(&hh_r).sigmoid();
        let z = ih_z.add_var(&hh_z).sigmoid();
        let n = ih_n.add_var(&r.mul_var(&hh_n)).tanh();

        // h' = (1 - z) * n + z * h, with (1 - z) built from a constant ones tensor.
        let shape = [batch_size, hidden_size];
        let ones = Variable::new(
            Tensor::from_vec(vec![1.0f32; batch_size * hidden_size], &shape).unwrap(),
            false,
        );
        let one_minus_z = ones.sub_var(&z);

        one_minus_z.mul_var(&n).add_var(&z.mul_var(hidden))
    }
}

impl Module for GRUCell {
    fn forward(&self, input: &Variable) -> Variable {
        // Without an explicit hidden state, start the recurrence from zeros.
        let batch_size = input.shape()[0];
        let hidden = Variable::new(
            zeros(&[batch_size, self.hidden_size]),
            input.requires_grad(),
        );
        self.forward_step(input, &hidden)
    }

    fn parameters(&self) -> Vec<Parameter> {
        vec![
            self.weight_ih.clone(),
            self.weight_hh.clone(),
            self.bias_ih.clone(),
            self.bias_hh.clone(),
        ]
    }

    fn name(&self) -> &'static str {
        "GRUCell"
    }
}

/// A multi-layer GRU applied over a full sequence.
pub struct GRU {
    cells: Vec<GRUCell>,
    hidden_size: usize,
    num_layers: usize,
    batch_first: bool,
}

impl GRU {
    /// Creates a batch-first GRU (`[batch, seq, features]` input layout).
    pub fn new(input_size: usize, hidden_size: usize, num_layers: usize) -> Self {
        let mut cells = Vec::with_capacity(num_layers);
        cells.push(GRUCell::new(input_size, hidden_size));
        for _ in 1..num_layers {
            cells.push(GRUCell::new(hidden_size, hidden_size));
        }
        Self {
            cells,
            hidden_size,
            num_layers,
            batch_first: true,
        }
    }

    /// Returns the hidden-state size.
    pub fn hidden_size(&self) -> usize {
        self.hidden_size
    }

    /// Returns the number of stacked layers.
    pub fn num_layers(&self) -> usize {
        self.num_layers
    }
}

impl Module for GRU {
    fn forward(&self, input: &Variable) -> Variable {
        let shape = input.shape();
        let (batch_size, seq_len) = if self.batch_first {
            (shape[0], shape[1])
        } else {
            (shape[1], shape[0])
        };
        let time_dim = if self.batch_first { 1 } else { 0 };

        // One hidden state per layer, initialized to zeros.
        let mut hidden_states: Vec<Variable> = (0..self.num_layers)
            .map(|_| {
                Variable::new(
                    zeros(&[batch_size, self.hidden_size]),
                    input.requires_grad(),
                )
            })
            .collect();

        let mut output_vars: Vec<Variable> = Vec::with_capacity(seq_len);

        for t in 0..seq_len {
            // Slice out timestep `t` and flatten it to `[batch, features]`.
            let narrowed = input.narrow(time_dim, t, 1);
            let step_input = narrowed.reshape(&[batch_size, narrowed.data().numel() / batch_size]);

            // Feed the timestep up through the layer stack.
            let mut layer_input = step_input;
            for (layer_idx, cell) in self.cells.iter().enumerate() {
                let new_hidden = cell.forward_step(&layer_input, &hidden_states[layer_idx]);
                hidden_states[layer_idx] = new_hidden.clone();
                layer_input = new_hidden;
            }

            // The top layer's hidden state is the output for this step.
            output_vars.push(layer_input);
        }

        self.stack_outputs(&output_vars, batch_size, seq_len)
    }

    fn parameters(&self) -> Vec<Parameter> {
        self.cells.iter().flat_map(|c| c.parameters()).collect()
    }

    fn name(&self) -> &'static str {
        "GRU"
    }
}

impl GRU {
    /// Runs the GRU and returns the mean over time of the top layer's
    /// outputs, shape `[batch, hidden_size]`.
    pub fn forward_mean(&self, input: &Variable) -> Variable {
        let shape = input.shape();
        let (batch_size, seq_len) = if self.batch_first {
            (shape[0], shape[1])
        } else {
            (shape[1], shape[0])
        };
        let time_dim = if self.batch_first { 1 } else { 0 };

        let mut hidden_states: Vec<Variable> = (0..self.num_layers)
            .map(|_| {
                Variable::new(
                    zeros(&[batch_size, self.hidden_size]),
                    input.requires_grad(),
                )
            })
            .collect();

        // Accumulate the top-layer output at each step, then divide by the
        // sequence length.
        let mut output_sum: Option<Variable> = None;

        for t in 0..seq_len {
            let narrowed = input.narrow(time_dim, t, 1);
            let step_input = narrowed.reshape(&[batch_size, narrowed.data().numel() / batch_size]);

            let mut layer_input = step_input;
            for (layer_idx, cell) in self.cells.iter().enumerate() {
                let new_hidden = cell.forward_step(&layer_input, &hidden_states[layer_idx]);
                hidden_states[layer_idx] = new_hidden.clone();
                layer_input = new_hidden;
            }

            output_sum = Some(match output_sum {
                None => layer_input,
                Some(acc) => acc.add_var(&layer_input),
            });
        }

        match output_sum {
            Some(sum) => sum.mul_scalar(1.0 / seq_len as f32),
            None => Variable::new(zeros(&[batch_size, self.hidden_size]), false),
        }
    }

    /// Runs the GRU and returns the top layer's final hidden state, shape
    /// `[batch, hidden_size]`.
    pub fn forward_last(&self, input: &Variable) -> Variable {
        let shape = input.shape();
        let (batch_size, seq_len) = if self.batch_first {
            (shape[0], shape[1])
        } else {
            (shape[1], shape[0])
        };
        let time_dim = if self.batch_first { 1 } else { 0 };

        let mut hidden_states: Vec<Variable> = (0..self.num_layers)
            .map(|_| {
                Variable::new(
                    zeros(&[batch_size, self.hidden_size]),
                    input.requires_grad(),
                )
            })
            .collect();

        for t in 0..seq_len {
            let narrowed = input.narrow(time_dim, t, 1);
            let step_input = narrowed.reshape(&[batch_size, narrowed.data().numel() / batch_size]);

            let mut layer_input = step_input;
            for (layer_idx, cell) in self.cells.iter().enumerate() {
                let new_hidden = cell.forward_step(&layer_input, &hidden_states[layer_idx]);
                hidden_states[layer_idx] = new_hidden.clone();
                layer_input = new_hidden;
            }
        }

        hidden_states
            .pop()
            .unwrap_or_else(|| Variable::new(zeros(&[batch_size, self.hidden_size]), false))
    }

    /// Packs per-step outputs of shape `[batch, hidden_size]` into a single
    /// sequence tensor in the configured layout. Note that this copies raw
    /// data, so the packed output is detached from the autograd graph.
    fn stack_outputs(&self, outputs: &[Variable], batch_size: usize, seq_len: usize) -> Variable {
        if outputs.is_empty() {
            return Variable::new(zeros(&[batch_size, 0, self.hidden_size]), false);
        }

        let output_shape = if self.batch_first {
            [batch_size, seq_len, self.hidden_size]
        } else {
            [seq_len, batch_size, self.hidden_size]
        };
        let requires_grad = outputs.iter().any(|o| o.requires_grad());

        let mut stacked_data = vec![0.0f32; batch_size * seq_len * self.hidden_size];
        for (t, out) in outputs.iter().enumerate() {
            let out_data = out.data().to_vec();
            for b in 0..batch_size {
                for h in 0..self.hidden_size {
                    let idx = if self.batch_first {
                        b * seq_len * self.hidden_size + t * self.hidden_size + h
                    } else {
                        t * batch_size * self.hidden_size + b * self.hidden_size + h
                    };
                    stacked_data[idx] = out_data[b * self.hidden_size + h];
                }
            }
        }

        Variable::new(
            Tensor::from_vec(stacked_data, &output_shape).unwrap(),
            requires_grad,
        )
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_rnn_cell() {
        let cell = RNNCell::new(10, 20);
        let input = Variable::new(Tensor::from_vec(vec![1.0; 20], &[2, 10]).unwrap(), false);
        let hidden = Variable::new(Tensor::from_vec(vec![0.0; 40], &[2, 20]).unwrap(), false);
        let output = cell.forward_step(&input, &hidden);
        assert_eq!(output.shape(), vec![2, 20]);
    }
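
    // Shape-only sanity checks for the other cell types, mirroring
    // `test_rnn_cell` above; output values are not verified against a
    // reference implementation.
    #[test]
    fn test_lstm_cell() {
        let cell = LSTMCell::new(10, 20);
        let input = Variable::new(Tensor::from_vec(vec![1.0; 20], &[2, 10]).unwrap(), false);
        let h = Variable::new(Tensor::from_vec(vec![0.0; 40], &[2, 20]).unwrap(), false);
        let c = Variable::new(Tensor::from_vec(vec![0.0; 40], &[2, 20]).unwrap(), false);
        let (h_new, c_new) = cell.forward_step(&input, &(h, c));
        assert_eq!(h_new.shape(), vec![2, 20]);
        assert_eq!(c_new.shape(), vec![2, 20]);
    }

    #[test]
    fn test_gru_cell() {
        let cell = GRUCell::new(10, 20);
        let input = Variable::new(Tensor::from_vec(vec![1.0; 20], &[2, 10]).unwrap(), false);
        let hidden = Variable::new(Tensor::from_vec(vec![0.0; 40], &[2, 20]).unwrap(), false);
        let output = cell.forward_step(&input, &hidden);
        assert_eq!(output.shape(), vec![2, 20]);
    }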

    #[test]
    fn test_rnn() {
        let rnn = RNN::new(10, 20, 2);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 100], &[2, 5, 10]).unwrap(),
            false,
        );
        let output = rnn.forward(&input);
        assert_eq!(output.shape(), vec![2, 5, 20]);
    }
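
    // The same check through the seq-first (`[seq, batch, features]`) layout
    // selected via `with_options`.
    #[test]
    fn test_rnn_seq_first() {
        let rnn = RNN::with_options(10, 20, 2, false);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 100], &[5, 2, 10]).unwrap(),
            false,
        );
        let output = rnn.forward(&input);
        assert_eq!(output.shape(), vec![5, 2, 20]);
    }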

    #[test]
    fn test_lstm() {
        let lstm = LSTM::new(10, 20, 1);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 100], &[2, 5, 10]).unwrap(),
            false,
        );
        let output = lstm.forward(&input);
        assert_eq!(output.shape(), vec![2, 5, 20]);
    }
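
    // GRU coverage: the full-sequence output plus the `forward_last` and
    // `forward_mean` reductions, again checked by shape only.
    #[test]
    fn test_gru() {
        let gru = GRU::new(10, 20, 2);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 100], &[2, 5, 10]).unwrap(),
            false,
        );
        assert_eq!(gru.forward(&input).shape(), vec![2, 5, 20]);
        assert_eq!(gru.forward_last(&input).shape(), vec![2, 20]);
        assert_eq!(gru.forward_mean(&input).shape(), vec![2, 20]);
    }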
}