1use std::collections::HashMap;
18
19use axonml_autograd::Variable;
20
21use crate::init::{xavier_uniform, zeros};
22use crate::module::Module;
23use crate::parameter::Parameter;
24
/// A single vanilla (Elman) RNN cell:
/// `h' = tanh(x·W_ihᵀ + b_ih + h·W_hhᵀ + b_hh)`.
pub struct RNNCell {
    /// Input-to-hidden weights (applied transposed in `forward_step`).
    pub weight_ih: Parameter,
    /// Hidden-to-hidden weights (applied transposed in `forward_step`).
    pub weight_hh: Parameter,
    /// Input-to-hidden bias, length `hidden_size`.
    pub bias_ih: Parameter,
    /// Hidden-to-hidden bias, length `hidden_size`.
    pub bias_hh: Parameter,
    /// Expected size of the last input dimension.
    input_size: usize,
    /// Size of the hidden state.
    hidden_size: usize,
}
46
47impl RNNCell {
48 pub fn new(input_size: usize, hidden_size: usize) -> Self {
50 Self {
51 weight_ih: Parameter::named("weight_ih", xavier_uniform(input_size, hidden_size), true),
52 weight_hh: Parameter::named(
53 "weight_hh",
54 xavier_uniform(hidden_size, hidden_size),
55 true,
56 ),
57 bias_ih: Parameter::named("bias_ih", zeros(&[hidden_size]), true),
58 bias_hh: Parameter::named("bias_hh", zeros(&[hidden_size]), true),
59 input_size,
60 hidden_size,
61 }
62 }
63
64 pub fn input_size(&self) -> usize {
66 self.input_size
67 }
68
69 pub fn hidden_size(&self) -> usize {
71 self.hidden_size
72 }
73
74 pub fn forward_step(&self, input: &Variable, hidden: &Variable) -> Variable {
76 let input_features = input.data().shape().last().copied().unwrap_or(0);
77 assert_eq!(
78 input_features, self.input_size,
79 "RNNCell: expected input size {}, got {}",
80 self.input_size, input_features
81 );
82 let weight_ih = self.weight_ih.variable();
84 let weight_ih_t = weight_ih.transpose(0, 1);
85 let ih = input.matmul(&weight_ih_t);
86 let bias_ih = self.bias_ih.variable();
87 let ih = ih.add_var(&bias_ih);
88
89 let weight_hh = self.weight_hh.variable();
91 let weight_hh_t = weight_hh.transpose(0, 1);
92 let hh = hidden.matmul(&weight_hh_t);
93 let bias_hh = self.bias_hh.variable();
94 let hh = hh.add_var(&bias_hh);
95
96 ih.add_var(&hh).tanh()
98 }
99}
100
101impl Module for RNNCell {
102 fn forward(&self, input: &Variable) -> Variable {
103 let batch_size = input.shape()[0];
105 let hidden = Variable::new(
106 zeros(&[batch_size, self.hidden_size]),
107 input.requires_grad(),
108 );
109 self.forward_step(input, &hidden)
110 }
111
112 fn parameters(&self) -> Vec<Parameter> {
113 vec![
114 self.weight_ih.clone(),
115 self.weight_hh.clone(),
116 self.bias_ih.clone(),
117 self.bias_hh.clone(),
118 ]
119 }
120
121 fn named_parameters(&self) -> HashMap<String, Parameter> {
122 let mut params = HashMap::new();
123 params.insert("weight_ih".to_string(), self.weight_ih.clone());
124 params.insert("weight_hh".to_string(), self.weight_hh.clone());
125 params.insert("bias_ih".to_string(), self.bias_ih.clone());
126 params.insert("bias_hh".to_string(), self.bias_hh.clone());
127 params
128 }
129
130 fn name(&self) -> &'static str {
131 "RNNCell"
132 }
133}
134
/// A multi-layer Elman RNN unrolled over a sequence.
pub struct RNN {
    /// One cell per stacked layer; cell 0 consumes the raw input.
    cells: Vec<RNNCell>,
    /// Input feature size (retained for construction symmetry; currently unused).
    _input_size: usize,
    /// Hidden state size shared by every layer.
    hidden_size: usize,
    /// Number of stacked layers.
    num_layers: usize,
    /// When true, the layout is `[batch, seq, features]`;
    /// otherwise `[seq, batch, features]`.
    batch_first: bool,
}
154
155impl RNN {
156 pub fn new(input_size: usize, hidden_size: usize, num_layers: usize) -> Self {
158 Self::with_options(input_size, hidden_size, num_layers, true)
159 }
160
161 pub fn with_options(
163 input_size: usize,
164 hidden_size: usize,
165 num_layers: usize,
166 batch_first: bool,
167 ) -> Self {
168 let mut cells = Vec::with_capacity(num_layers);
169
170 cells.push(RNNCell::new(input_size, hidden_size));
172
173 for _ in 1..num_layers {
175 cells.push(RNNCell::new(hidden_size, hidden_size));
176 }
177
178 Self {
179 cells,
180 _input_size: input_size,
181 hidden_size,
182 num_layers,
183 batch_first,
184 }
185 }
186}
187
188impl Module for RNN {
189 fn forward(&self, input: &Variable) -> Variable {
190 let shape = input.shape();
191 let (batch_size, seq_len, input_features) = if self.batch_first {
192 (shape[0], shape[1], shape[2])
193 } else {
194 (shape[1], shape[0], shape[2])
195 };
196
197 let mut hiddens: Vec<Variable> = (0..self.num_layers)
199 .map(|_| {
200 Variable::new(
201 zeros(&[batch_size, self.hidden_size]),
202 input.requires_grad(),
203 )
204 })
205 .collect();
206
207 let cell0 = &self.cells[0];
209 let input_2d = input.reshape(&[batch_size * seq_len, input_features]);
210 let w_ih_t = cell0.weight_ih.variable().transpose(0, 1);
211 let ih_all = input_2d.matmul(&w_ih_t).add_var(&cell0.bias_ih.variable());
212 let ih_all_3d = ih_all.reshape(&[batch_size, seq_len, self.hidden_size]);
213
214 let w_hh_t_0 = cell0.weight_hh.variable().transpose(0, 1);
216 let bias_hh_0 = cell0.bias_hh.variable();
217
218 let mut outputs = Vec::with_capacity(seq_len);
219
220 for t in 0..seq_len {
221 let ih_t = ih_all_3d.select(1, t);
223 let hh = hiddens[0].matmul(&w_hh_t_0).add_var(&bias_hh_0);
224 hiddens[0] = ih_t.add_var(&hh).tanh();
225
226 for l in 1..self.num_layers {
228 let layer_input = hiddens[l - 1].clone();
229 hiddens[l] = self.cells[l].forward_step(&layer_input, &hiddens[l]);
230 }
231
232 outputs.push(hiddens[self.num_layers - 1].clone());
233 }
234
235 let time_dim = usize::from(self.batch_first);
237 let unsqueezed: Vec<Variable> = outputs.iter().map(|o| o.unsqueeze(time_dim)).collect();
238 let refs: Vec<&Variable> = unsqueezed.iter().collect();
239 Variable::cat(&refs, time_dim)
240 }
241
242 fn parameters(&self) -> Vec<Parameter> {
243 self.cells.iter().flat_map(|c| c.parameters()).collect()
244 }
245
246 fn name(&self) -> &'static str {
247 "RNN"
248 }
249}
250
/// A single LSTM cell; the four gates (input, forget, cell, output) are
/// packed along the output dimension of each weight/bias, hence the
/// `4 * hidden_size` sizes set up in `new`.
pub struct LSTMCell {
    /// Input-to-hidden weights (applied transposed), covering all four gates.
    pub weight_ih: Parameter,
    /// Hidden-to-hidden weights (applied transposed), covering all four gates.
    pub weight_hh: Parameter,
    /// Input-to-hidden bias, length `4 * hidden_size`.
    pub bias_ih: Parameter,
    /// Hidden-to-hidden bias, length `4 * hidden_size`.
    pub bias_hh: Parameter,
    /// Expected size of the last input dimension.
    input_size: usize,
    /// Size of the hidden and cell states.
    hidden_size: usize,
}
270
271impl LSTMCell {
272 pub fn new(input_size: usize, hidden_size: usize) -> Self {
274 Self {
276 weight_ih: Parameter::named(
277 "weight_ih",
278 xavier_uniform(input_size, 4 * hidden_size),
279 true,
280 ),
281 weight_hh: Parameter::named(
282 "weight_hh",
283 xavier_uniform(hidden_size, 4 * hidden_size),
284 true,
285 ),
286 bias_ih: Parameter::named("bias_ih", zeros(&[4 * hidden_size]), true),
287 bias_hh: Parameter::named("bias_hh", zeros(&[4 * hidden_size]), true),
288 input_size,
289 hidden_size,
290 }
291 }
292
293 pub fn input_size(&self) -> usize {
295 self.input_size
296 }
297
298 pub fn hidden_size(&self) -> usize {
300 self.hidden_size
301 }
302
303 pub fn forward_step(
305 &self,
306 input: &Variable,
307 hx: &(Variable, Variable),
308 ) -> (Variable, Variable) {
309 let input_features = input.data().shape().last().copied().unwrap_or(0);
310 assert_eq!(
311 input_features, self.input_size,
312 "LSTMCell: expected input size {}, got {}",
313 self.input_size, input_features
314 );
315
316 let (h, c) = hx;
317
318 let weight_ih = self.weight_ih.variable();
320 let weight_ih_t = weight_ih.transpose(0, 1);
321 let ih = input.matmul(&weight_ih_t);
322 let bias_ih = self.bias_ih.variable();
323 let ih = ih.add_var(&bias_ih);
324
325 let weight_hh = self.weight_hh.variable();
326 let weight_hh_t = weight_hh.transpose(0, 1);
327 let hh = h.matmul(&weight_hh_t);
328 let bias_hh = self.bias_hh.variable();
329 let hh = hh.add_var(&bias_hh);
330
331 let gates = ih.add_var(&hh);
332 let hs = self.hidden_size;
333
334 let i = gates.narrow(1, 0, hs).sigmoid();
336 let f = gates.narrow(1, hs, hs).sigmoid();
337 let g = gates.narrow(1, 2 * hs, hs).tanh();
338 let o = gates.narrow(1, 3 * hs, hs).sigmoid();
339
340 let c_new = f.mul_var(c).add_var(&i.mul_var(&g));
342
343 let h_new = o.mul_var(&c_new.tanh());
345
346 (h_new, c_new)
347 }
348}
349
350impl Module for LSTMCell {
351 fn forward(&self, input: &Variable) -> Variable {
352 let batch_size = input.shape()[0];
353 let h = Variable::new(
354 zeros(&[batch_size, self.hidden_size]),
355 input.requires_grad(),
356 );
357 let c = Variable::new(
358 zeros(&[batch_size, self.hidden_size]),
359 input.requires_grad(),
360 );
361 let (h_new, _) = self.forward_step(input, &(h, c));
362 h_new
363 }
364
365 fn parameters(&self) -> Vec<Parameter> {
366 vec![
367 self.weight_ih.clone(),
368 self.weight_hh.clone(),
369 self.bias_ih.clone(),
370 self.bias_hh.clone(),
371 ]
372 }
373
374 fn named_parameters(&self) -> HashMap<String, Parameter> {
375 let mut params = HashMap::new();
376 params.insert("weight_ih".to_string(), self.weight_ih.clone());
377 params.insert("weight_hh".to_string(), self.weight_hh.clone());
378 params.insert("bias_ih".to_string(), self.bias_ih.clone());
379 params.insert("bias_hh".to_string(), self.bias_hh.clone());
380 params
381 }
382
383 fn name(&self) -> &'static str {
384 "LSTMCell"
385 }
386}
387
/// A multi-layer LSTM unrolled over a sequence.
pub struct LSTM {
    /// One cell per stacked layer; cell 0 consumes the raw input.
    cells: Vec<LSTMCell>,
    /// Expected size of the last input dimension.
    input_size: usize,
    /// Hidden/cell state size shared by every layer.
    hidden_size: usize,
    /// Number of stacked layers.
    num_layers: usize,
    /// When true, the layout is `[batch, seq, features]`;
    /// otherwise `[seq, batch, features]`.
    batch_first: bool,
}
405
406impl LSTM {
407 pub fn new(input_size: usize, hidden_size: usize, num_layers: usize) -> Self {
409 Self::with_options(input_size, hidden_size, num_layers, true)
410 }
411
412 pub fn with_options(
414 input_size: usize,
415 hidden_size: usize,
416 num_layers: usize,
417 batch_first: bool,
418 ) -> Self {
419 let mut cells = Vec::with_capacity(num_layers);
420 cells.push(LSTMCell::new(input_size, hidden_size));
421 for _ in 1..num_layers {
422 cells.push(LSTMCell::new(hidden_size, hidden_size));
423 }
424
425 Self {
426 cells,
427 input_size,
428 hidden_size,
429 num_layers,
430 batch_first,
431 }
432 }
433
434 pub fn input_size(&self) -> usize {
436 self.input_size
437 }
438
439 pub fn hidden_size(&self) -> usize {
441 self.hidden_size
442 }
443
444 pub fn num_layers(&self) -> usize {
446 self.num_layers
447 }
448}
449
impl Module for LSTM {
    /// Runs the LSTM over a full sequence; output layout matches the input
    /// layout (`batch_first` or not) with `hidden_size` features.
    fn forward(&self, input: &Variable) -> Variable {
        let shape = input.shape();
        let (batch_size, seq_len, input_features) = if self.batch_first {
            (shape[0], shape[1], shape[2])
        } else {
            (shape[1], shape[0], shape[2])
        };

        // One zero-initialized (hidden, cell) state pair per layer.
        let mut states: Vec<(Variable, Variable)> = (0..self.num_layers)
            .map(|_| {
                (
                    Variable::new(
                        zeros(&[batch_size, self.hidden_size]),
                        input.requires_grad(),
                    ),
                    Variable::new(
                        zeros(&[batch_size, self.hidden_size]),
                        input.requires_grad(),
                    ),
                )
            })
            .collect();

        // Layer 0 optimization: precompute the input projection for every
        // time step with one big matmul instead of `seq_len` small ones.
        let cell0 = &self.cells[0];
        let input_2d = input.reshape(&[batch_size * seq_len, input_features]);
        let w_ih_t = cell0.weight_ih.variable().transpose(0, 1);
        let ih_all = input_2d.matmul(&w_ih_t).add_var(&cell0.bias_ih.variable());
        let ih_all_3d = ih_all.reshape(&[batch_size, seq_len, 4 * self.hidden_size]);

        let w_hh_t_0 = cell0.weight_hh.variable().transpose(0, 1);
        let bias_hh_0 = cell0.bias_hh.variable();

        let mut outputs = Vec::with_capacity(seq_len);

        // GPU dispatch flag; compiles to a constant `false` without the
        // `cuda` feature.
        #[cfg(feature = "cuda")]
        let on_gpu = input.data().device().is_gpu();
        #[cfg(not(feature = "cuda"))]
        let on_gpu = false;

        for t in 0..seq_len {
            let ih_t = ih_all_3d.select(1, t);
            let (h, c) = &states[0];

            let hh = h.matmul(&w_hh_t_0).add_var(&bias_hh_0);

            // Packed gate pre-activations for layer 0 at step t.
            let gates = ih_t.add_var(&hh);

            if on_gpu {
                #[cfg(feature = "cuda")]
                {
                    // Fused CUDA kernel computes all four gates and the new
                    // (h, c) in one launch; the backward pass is wired up
                    // manually via LstmGatesBackward.
                    let hs = self.hidden_size;
                    let gates_data = gates.data();
                    let c_data = c.data();

                    if let Some((h_tensor, c_tensor)) = gates_data.lstm_gates_fused(&c_data, hs) {
                        // Save forward tensors needed by the backward kernel.
                        let saved_gates = gates_data.clone();
                        let saved_c_prev = c_data.clone();
                        let saved_c_new = c_tensor.clone();

                        let backward_fn = axonml_autograd::LstmGatesBackward::new(
                            gates.grad_fn().cloned(),
                            c.grad_fn().cloned(),
                            saved_gates,
                            saved_c_prev,
                            saved_c_new,
                            hs,
                        );
                        let grad_fn = axonml_autograd::GradFn::new(backward_fn);

                        let h_new = Variable::from_operation(
                            h_tensor,
                            grad_fn.clone(),
                            input.requires_grad(),
                        );
                        let c_new =
                            Variable::from_operation(c_tensor, grad_fn, input.requires_grad());
                        states[0] = (h_new, c_new);
                    }
                    // NOTE(review): if lstm_gates_fused returns None, states[0]
                    // is silently left unchanged for this step — confirm a CPU
                    // fallback is not needed here.
                }
            } else {
                // CPU path: slice gates (i, f, g, o) out of the packed tensor.
                let hs = self.hidden_size;
                let i_gate = gates.narrow(1, 0, hs).sigmoid();
                let f_gate = gates.narrow(1, hs, hs).sigmoid();
                let g_gate = gates.narrow(1, 2 * hs, hs).tanh();
                let o_gate = gates.narrow(1, 3 * hs, hs).sigmoid();
                // c' = f ⊙ c + i ⊙ g ;  h' = o ⊙ tanh(c')
                let c_new = f_gate.mul_var(c).add_var(&i_gate.mul_var(&g_gate));
                let h_new = o_gate.mul_var(&c_new.tanh());
                states[0] = (h_new, c_new);
            }

            // Deeper layers consume the layer below's fresh hidden state.
            for l in 1..self.num_layers {
                let layer_input = states[l - 1].0.clone();
                states[l] = self.cells[l].forward_step(&layer_input, &states[l]);
            }

            outputs.push(states[self.num_layers - 1].0.clone());
        }

        // Re-insert the time axis and concatenate per-step outputs along it.
        let time_dim = usize::from(self.batch_first);
        let unsqueezed: Vec<Variable> = outputs.iter().map(|o| o.unsqueeze(time_dim)).collect();
        let refs: Vec<&Variable> = unsqueezed.iter().collect();
        Variable::cat(&refs, time_dim)
    }

    fn parameters(&self) -> Vec<Parameter> {
        self.cells.iter().flat_map(|c| c.parameters()).collect()
    }

    /// Flat cell names for a single layer, `cells.{i}.` prefixes otherwise.
    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut params = HashMap::new();
        if self.cells.len() == 1 {
            for (n, p) in self.cells[0].named_parameters() {
                params.insert(n, p);
            }
        } else {
            for (i, cell) in self.cells.iter().enumerate() {
                for (n, p) in cell.named_parameters() {
                    params.insert(format!("cells.{i}.{n}"), p);
                }
            }
        }
        params
    }

    fn name(&self) -> &'static str {
        "LSTM"
    }
}
597
/// A single GRU cell; the three gates (reset, update, candidate) are packed
/// along the output dimension of each weight/bias, hence the
/// `3 * hidden_size` sizes set up in `new`.
pub struct GRUCell {
    /// Input-to-hidden weights (applied transposed), covering all three gates.
    pub weight_ih: Parameter,
    /// Hidden-to-hidden weights (applied transposed), covering all three gates.
    pub weight_hh: Parameter,
    /// Input-to-hidden bias, length `3 * hidden_size`.
    pub bias_ih: Parameter,
    /// Hidden-to-hidden bias, length `3 * hidden_size`.
    pub bias_hh: Parameter,
    /// Expected size of the last input dimension.
    input_size: usize,
    /// Size of the hidden state.
    hidden_size: usize,
}
623
624impl GRUCell {
625 pub fn new(input_size: usize, hidden_size: usize) -> Self {
627 Self {
628 weight_ih: Parameter::named(
629 "weight_ih",
630 xavier_uniform(input_size, 3 * hidden_size),
631 true,
632 ),
633 weight_hh: Parameter::named(
634 "weight_hh",
635 xavier_uniform(hidden_size, 3 * hidden_size),
636 true,
637 ),
638 bias_ih: Parameter::named("bias_ih", zeros(&[3 * hidden_size]), true),
639 bias_hh: Parameter::named("bias_hh", zeros(&[3 * hidden_size]), true),
640 input_size,
641 hidden_size,
642 }
643 }
644
645 pub fn input_size(&self) -> usize {
647 self.input_size
648 }
649
650 pub fn hidden_size(&self) -> usize {
652 self.hidden_size
653 }
654}
655
656impl GRUCell {
657 pub fn forward_step(&self, input: &Variable, hidden: &Variable) -> Variable {
667 let _batch_size = input.shape()[0];
668 let hidden_size = self.hidden_size;
669
670 let weight_ih = self.weight_ih.variable();
672 let weight_hh = self.weight_hh.variable();
673 let bias_ih = self.bias_ih.variable();
674 let bias_hh = self.bias_hh.variable();
675
676 let weight_ih_t = weight_ih.transpose(0, 1);
679 let ih = input.matmul(&weight_ih_t).add_var(&bias_ih);
680
681 let weight_hh_t = weight_hh.transpose(0, 1);
684 let hh = hidden.matmul(&weight_hh_t).add_var(&bias_hh);
685
686 let ih_r = ih.narrow(1, 0, hidden_size);
689 let ih_z = ih.narrow(1, hidden_size, hidden_size);
690 let ih_n = ih.narrow(1, 2 * hidden_size, hidden_size);
691
692 let hh_r = hh.narrow(1, 0, hidden_size);
693 let hh_z = hh.narrow(1, hidden_size, hidden_size);
694 let hh_n = hh.narrow(1, 2 * hidden_size, hidden_size);
695
696 let r = ih_r.add_var(&hh_r).sigmoid();
699
700 let z = ih_z.add_var(&hh_z).sigmoid();
702
703 let n = ih_n.add_var(&r.mul_var(&hh_n)).tanh();
705
706 let h_minus_n = hidden.sub_var(&n);
709 n.add_var(&z.mul_var(&h_minus_n))
710 }
711}
712
713impl Module for GRUCell {
714 fn forward(&self, input: &Variable) -> Variable {
715 let batch_size = input.shape()[0];
716
717 let hidden = Variable::new(
719 zeros(&[batch_size, self.hidden_size]),
720 input.requires_grad(),
721 );
722
723 self.forward_step(input, &hidden)
724 }
725
726 fn parameters(&self) -> Vec<Parameter> {
727 vec![
728 self.weight_ih.clone(),
729 self.weight_hh.clone(),
730 self.bias_ih.clone(),
731 self.bias_hh.clone(),
732 ]
733 }
734
735 fn named_parameters(&self) -> HashMap<String, Parameter> {
736 let mut params = HashMap::new();
737 params.insert("weight_ih".to_string(), self.weight_ih.clone());
738 params.insert("weight_hh".to_string(), self.weight_hh.clone());
739 params.insert("bias_ih".to_string(), self.bias_ih.clone());
740 params.insert("bias_hh".to_string(), self.bias_hh.clone());
741 params
742 }
743
744 fn name(&self) -> &'static str {
745 "GRUCell"
746 }
747}
748
/// A multi-layer GRU unrolled over a sequence.
pub struct GRU {
    /// One cell per stacked layer; cell 0 consumes the raw input.
    cells: Vec<GRUCell>,
    /// Hidden state size shared by every layer.
    hidden_size: usize,
    /// Number of stacked layers.
    num_layers: usize,
    /// Always `true` via `new` (no `with_options` constructor exists yet);
    /// kept as a field so the forward paths branch on it like `RNN`/`LSTM`.
    batch_first: bool,
}
760
761impl GRU {
762 pub fn new(input_size: usize, hidden_size: usize, num_layers: usize) -> Self {
764 let mut cells = Vec::with_capacity(num_layers);
765 cells.push(GRUCell::new(input_size, hidden_size));
766 for _ in 1..num_layers {
767 cells.push(GRUCell::new(hidden_size, hidden_size));
768 }
769 Self {
770 cells,
771 hidden_size,
772 num_layers,
773 batch_first: true,
774 }
775 }
776
777 pub fn hidden_size(&self) -> usize {
779 self.hidden_size
780 }
781
782 pub fn num_layers(&self) -> usize {
784 self.num_layers
785 }
786}
787
impl Module for GRU {
    /// Runs the GRU over a full sequence; per-step top-layer outputs are
    /// stacked by `stack_outputs`.
    fn forward(&self, input: &Variable) -> Variable {
        let shape = input.shape();
        let (batch_size, seq_len, input_features) = if self.batch_first {
            (shape[0], shape[1], shape[2])
        } else {
            (shape[1], shape[0], shape[2])
        };

        // One zero-initialized hidden state per layer.
        let mut hidden_states: Vec<Variable> = (0..self.num_layers)
            .map(|_| {
                Variable::new(
                    zeros(&[batch_size, self.hidden_size]),
                    input.requires_grad(),
                )
            })
            .collect();

        // Layer 0 optimization: precompute the input projection for every
        // time step with one big matmul instead of `seq_len` small ones.
        let cell0 = &self.cells[0];
        let input_2d = input.reshape(&[batch_size * seq_len, input_features]);
        let w_ih_t = cell0.weight_ih.variable().transpose(0, 1);
        let ih_all = input_2d.matmul(&w_ih_t).add_var(&cell0.bias_ih.variable());
        let ih_all_3d = ih_all.reshape(&[batch_size, seq_len, 3 * self.hidden_size]);

        let w_hh_t_0 = cell0.weight_hh.variable().transpose(0, 1);
        let bias_hh_0 = cell0.bias_hh.variable();

        let mut output_vars: Vec<Variable> = Vec::with_capacity(seq_len);

        // GPU dispatch flag; compiles to a constant `false` without the
        // `cuda` feature.
        #[cfg(feature = "cuda")]
        let on_gpu = input.data().device().is_gpu();
        #[cfg(not(feature = "cuda"))]
        let on_gpu = false;

        for t in 0..seq_len {
            let ih_t = ih_all_3d.select(1, t);
            let hidden = &hidden_states[0];
            let hs = self.hidden_size;

            let hh = hidden.matmul(&w_hh_t_0).add_var(&bias_hh_0);

            if on_gpu {
                #[cfg(feature = "cuda")]
                {
                    // Fused CUDA kernel computes all three gates and the new
                    // hidden state in one launch; the backward pass is wired
                    // up manually via GruGatesBackward.
                    let ih_data = ih_t.data();
                    let hh_data = hh.data();
                    let h_data = hidden.data();

                    if let Some(h_tensor) = ih_data.gru_gates_fused(&hh_data, &h_data, hs) {
                        // Save forward tensors needed by the backward kernel.
                        let saved_ih = ih_data.clone();
                        let saved_hh = hh_data.clone();
                        let saved_h_prev = h_data.clone();

                        let backward_fn = axonml_autograd::GruGatesBackward::new(
                            ih_t.grad_fn().cloned(),
                            hh.grad_fn().cloned(),
                            hidden.grad_fn().cloned(),
                            saved_ih,
                            saved_hh,
                            saved_h_prev,
                            hs,
                        );
                        let grad_fn = axonml_autograd::GradFn::new(backward_fn);

                        let h_new =
                            Variable::from_operation(h_tensor, grad_fn, input.requires_grad());
                        hidden_states[0] = h_new;
                    }
                    // NOTE(review): if gru_gates_fused returns None, the
                    // hidden state is silently left unchanged for this step —
                    // confirm a CPU fallback is not needed here.
                }
            } else {
                // CPU path: slice r (reset), z (update), n (candidate) out of
                // the packed projections along dim 1.
                let ih_r = ih_t.narrow(1, 0, hs);
                let ih_z = ih_t.narrow(1, hs, hs);
                let ih_n = ih_t.narrow(1, 2 * hs, hs);
                let hh_r = hh.narrow(1, 0, hs);
                let hh_z = hh.narrow(1, hs, hs);
                let hh_n = hh.narrow(1, 2 * hs, hs);

                let r = ih_r.add_var(&hh_r).sigmoid();
                let z = ih_z.add_var(&hh_z).sigmoid();
                let n = ih_n.add_var(&r.mul_var(&hh_n)).tanh();
                // h' = n + z ⊙ (h - n)
                let h_minus_n = hidden.sub_var(&n);
                let h_new = n.add_var(&z.mul_var(&h_minus_n));
                hidden_states[0] = h_new;
            }

            // Deeper layers consume the layer below's fresh hidden state.
            let mut layer_output = hidden_states[0].clone();
            for l in 1..self.num_layers {
                let new_hidden = self.cells[l].forward_step(&layer_output, &hidden_states[l]);
                hidden_states[l] = new_hidden.clone();
                layer_output = new_hidden;
            }

            output_vars.push(layer_output);
        }

        self.stack_outputs(&output_vars, batch_size, seq_len)
    }

    fn parameters(&self) -> Vec<Parameter> {
        self.cells.iter().flat_map(|c| c.parameters()).collect()
    }

    /// Flat cell names for a single layer, `cells.{i}.` prefixes otherwise.
    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut params = HashMap::new();
        if self.cells.len() == 1 {
            for (n, p) in self.cells[0].named_parameters() {
                params.insert(n, p);
            }
        } else {
            for (i, cell) in self.cells.iter().enumerate() {
                for (n, p) in cell.named_parameters() {
                    params.insert(format!("cells.{i}.{n}"), p);
                }
            }
        }
        params
    }

    fn name(&self) -> &'static str {
        "GRU"
    }
}
923
924impl GRU {
925 pub fn forward_mean(&self, input: &Variable) -> Variable {
928 let shape = input.shape();
929 let (batch_size, seq_len, input_features) = if self.batch_first {
930 (shape[0], shape[1], shape[2])
931 } else {
932 (shape[1], shape[0], shape[2])
933 };
934
935 let mut hidden_states: Vec<Variable> = (0..self.num_layers)
936 .map(|_| {
937 Variable::new(
938 zeros(&[batch_size, self.hidden_size]),
939 input.requires_grad(),
940 )
941 })
942 .collect();
943
944 let cell0 = &self.cells[0];
946 let input_2d = input.reshape(&[batch_size * seq_len, input_features]);
947 let w_ih_t = cell0.weight_ih.variable().transpose(0, 1);
948 let ih_all = input_2d.matmul(&w_ih_t).add_var(&cell0.bias_ih.variable());
949 let ih_all_3d = ih_all.reshape(&[batch_size, seq_len, 3 * self.hidden_size]);
950
951 let w_hh_t_0 = cell0.weight_hh.variable().transpose(0, 1);
953 let bias_hh_0 = cell0.bias_hh.variable();
954
955 let mut output_sum: Option<Variable> = None;
956 let hs = self.hidden_size;
957
958 for t in 0..seq_len {
959 let ih_t = ih_all_3d.select(1, t);
961 let hidden = &hidden_states[0];
962 let hh = hidden.matmul(&w_hh_t_0).add_var(&bias_hh_0);
963
964 let ih_r = ih_t.narrow(1, 0, hs);
965 let ih_z = ih_t.narrow(1, hs, hs);
966 let ih_n = ih_t.narrow(1, 2 * hs, hs);
967 let hh_r = hh.narrow(1, 0, hs);
968 let hh_z = hh.narrow(1, hs, hs);
969 let hh_n = hh.narrow(1, 2 * hs, hs);
970
971 let r = ih_r.add_var(&hh_r).sigmoid();
972 let z = ih_z.add_var(&hh_z).sigmoid();
973 let n = ih_n.add_var(&r.mul_var(&hh_n)).tanh();
974 let h_minus_n = hidden.sub_var(&n);
975 let h_new = n.add_var(&z.mul_var(&h_minus_n));
976 hidden_states[0] = h_new.clone();
977
978 let mut layer_output = h_new;
980 for l in 1..self.num_layers {
981 let new_hidden = self.cells[l].forward_step(&layer_output, &hidden_states[l]);
982 hidden_states[l] = new_hidden.clone();
983 layer_output = new_hidden;
984 }
985
986 output_sum = Some(match output_sum {
987 None => layer_output,
988 Some(acc) => acc.add_var(&layer_output),
989 });
990 }
991
992 match output_sum {
993 Some(sum) => sum.mul_scalar(1.0 / seq_len as f32),
994 None => Variable::new(zeros(&[batch_size, self.hidden_size]), false),
995 }
996 }
997
998 pub fn forward_last(&self, input: &Variable) -> Variable {
1001 let shape = input.shape();
1002 let (batch_size, seq_len, input_features) = if self.batch_first {
1003 (shape[0], shape[1], shape[2])
1004 } else {
1005 (shape[1], shape[0], shape[2])
1006 };
1007
1008 let mut hidden_states: Vec<Variable> = (0..self.num_layers)
1009 .map(|_| {
1010 Variable::new(
1011 zeros(&[batch_size, self.hidden_size]),
1012 input.requires_grad(),
1013 )
1014 })
1015 .collect();
1016
1017 let cell0 = &self.cells[0];
1019 let input_2d = input.reshape(&[batch_size * seq_len, input_features]);
1020 let w_ih_t = cell0.weight_ih.variable().transpose(0, 1);
1021 let ih_all = input_2d.matmul(&w_ih_t).add_var(&cell0.bias_ih.variable());
1022 let ih_all_3d = ih_all.reshape(&[batch_size, seq_len, 3 * self.hidden_size]);
1023
1024 let w_hh_t_0 = cell0.weight_hh.variable().transpose(0, 1);
1026 let bias_hh_0 = cell0.bias_hh.variable();
1027 let hs = self.hidden_size;
1028
1029 for t in 0..seq_len {
1030 let ih_t = ih_all_3d.select(1, t);
1032 let hidden = &hidden_states[0];
1033 let hh = hidden.matmul(&w_hh_t_0).add_var(&bias_hh_0);
1034
1035 let ih_r = ih_t.narrow(1, 0, hs);
1036 let ih_z = ih_t.narrow(1, hs, hs);
1037 let ih_n = ih_t.narrow(1, 2 * hs, hs);
1038 let hh_r = hh.narrow(1, 0, hs);
1039 let hh_z = hh.narrow(1, hs, hs);
1040 let hh_n = hh.narrow(1, 2 * hs, hs);
1041
1042 let r = ih_r.add_var(&hh_r).sigmoid();
1043 let z = ih_z.add_var(&hh_z).sigmoid();
1044 let n = ih_n.add_var(&r.mul_var(&hh_n)).tanh();
1045 let h_minus_n = hidden.sub_var(&n);
1046 let h_new = n.add_var(&z.mul_var(&h_minus_n));
1047 hidden_states[0] = h_new.clone();
1048
1049 let mut layer_input = h_new;
1051
1052 for (layer_idx, cell) in self.cells.iter().enumerate().skip(1) {
1053 let new_hidden = cell.forward_step(&layer_input, &hidden_states[layer_idx]);
1054 hidden_states[layer_idx] = new_hidden.clone();
1055 layer_input = new_hidden;
1056 }
1057 }
1058
1059 hidden_states
1061 .pop()
1062 .unwrap_or_else(|| Variable::new(zeros(&[batch_size, self.hidden_size]), false))
1063 }
1064
1065 fn stack_outputs(&self, outputs: &[Variable], batch_size: usize, _seq_len: usize) -> Variable {
1069 if outputs.is_empty() {
1070 return Variable::new(zeros(&[batch_size, 0, self.hidden_size]), false);
1071 }
1072
1073 let unsqueezed: Vec<Variable> = outputs.iter().map(|o| o.unsqueeze(1)).collect();
1075 let refs: Vec<&Variable> = unsqueezed.iter().collect();
1076 Variable::cat(&refs, 1)
1077 }
1078}
1079
#[cfg(test)]
mod tests {
    use super::*;
    use axonml_tensor::Tensor;

    // Single RNN cell step preserves the [batch, hidden] shape.
    #[test]
    fn test_rnn_cell() {
        let cell = RNNCell::new(10, 20);
        let input = Variable::new(Tensor::from_vec(vec![1.0; 20], &[2, 10]).unwrap(), false);
        let hidden = Variable::new(Tensor::from_vec(vec![0.0; 40], &[2, 20]).unwrap(), false);
        let output = cell.forward_step(&input, &hidden);
        assert_eq!(output.shape(), vec![2, 20]);
    }

    // Two-layer RNN over a batch_first sequence keeps [batch, seq, hidden].
    #[test]
    fn test_rnn() {
        let rnn = RNN::new(10, 20, 2);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 100], &[2, 5, 10]).unwrap(),
            false,
        );
        let output = rnn.forward(&input);
        assert_eq!(output.shape(), vec![2, 5, 20]);
    }

    // Single-layer LSTM over a batch_first sequence keeps [batch, seq, hidden].
    #[test]
    fn test_lstm() {
        let lstm = LSTM::new(10, 20, 1);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 100], &[2, 5, 10]).unwrap(),
            false,
        );
        let output = lstm.forward(&input);
        assert_eq!(output.shape(), vec![2, 5, 20]);
    }

    // Backprop through a GRU forward pass must deposit a non-zero gradient in
    // at least one parameter; the prints aid debugging when it fails.
    #[test]
    fn test_gru_gradients_reach_parameters() {
        let gru = GRU::new(4, 8, 1);
        let input = Variable::new(
            Tensor::from_vec(vec![0.5f32; 2 * 3 * 4], &[2, 3, 4]).unwrap(),
            true,
        );
        let output = gru.forward(&input);
        println!(
            "Output shape: {:?}, requires_grad: {}",
            output.shape(),
            output.requires_grad()
        );
        let loss = output.sum();
        println!(
            "Loss: {:?}, requires_grad: {}",
            loss.data().to_vec(),
            loss.requires_grad()
        );
        loss.backward();

        println!(
            "Input grad: {:?}",
            input
                .grad()
                .map(|g| g.to_vec().iter().map(|x| x.abs()).sum::<f32>())
        );

        let params = gru.parameters();
        println!("Number of parameters: {}", params.len());
        let mut has_grad = false;
        for (i, p) in params.iter().enumerate() {
            let grad = p.grad();
            match grad {
                Some(g) => {
                    let gv = g.to_vec();
                    let sum_abs: f32 = gv.iter().map(|x| x.abs()).sum();
                    println!(
                        "Param {} shape {:?} requires_grad={}: grad sum_abs={:.6}",
                        i,
                        p.shape(),
                        p.requires_grad(),
                        sum_abs
                    );
                    if sum_abs > 0.0 {
                        has_grad = true;
                    }
                }
                None => {
                    println!(
                        "Param {} shape {:?} requires_grad={}: NO GRADIENT",
                        i,
                        p.shape(),
                        p.requires_grad()
                    );
                }
            }
        }
        assert!(
            has_grad,
            "At least one GRU parameter should have non-zero gradients"
        );
    }
}