use std::collections::HashMap;

use axonml_autograd::Variable;

use crate::init::{xavier_uniform, zeros};
use crate::module::Module;
use crate::parameter::Parameter;

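/// A single-step vanilla (tanh) RNN cell.
///
/// One step computes `h' = tanh(x W_ih^T + b_ih + h W_hh^T + b_hh)`.
///
/// A minimal usage sketch (shapes mirror `test_rnn_cell` below):
///
/// ```ignore
/// let cell = RNNCell::new(10, 20);
/// let x = Variable::new(Tensor::from_vec(vec![1.0; 20], &[2, 10]).unwrap(), false);
/// let h0 = Variable::new(zeros(&[2, 20]), false);
/// let h1 = cell.forward_step(&x, &h0); // [2, 20]
/// ```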
pub struct RNNCell {
    pub weight_ih: Parameter,
    pub weight_hh: Parameter,
    pub bias_ih: Parameter,
    pub bias_hh: Parameter,
    input_size: usize,
    hidden_size: usize,
}

impl RNNCell {
    pub fn new(input_size: usize, hidden_size: usize) -> Self {
        Self {
            weight_ih: Parameter::named("weight_ih", xavier_uniform(input_size, hidden_size), true),
            weight_hh: Parameter::named(
                "weight_hh",
                xavier_uniform(hidden_size, hidden_size),
                true,
            ),
            bias_ih: Parameter::named("bias_ih", zeros(&[hidden_size]), true),
            bias_hh: Parameter::named("bias_hh", zeros(&[hidden_size]), true),
            input_size,
            hidden_size,
        }
    }

    pub fn input_size(&self) -> usize {
        self.input_size
    }

    pub fn hidden_size(&self) -> usize {
        self.hidden_size
    }

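    /// Runs one recurrence step: `tanh(x W_ih^T + b_ih + h W_hh^T + b_hh)`.
    ///
    /// Panics if the trailing dimension of `input` does not match `input_size`.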
    pub fn forward_step(&self, input: &Variable, hidden: &Variable) -> Variable {
        let input_features = input.data().shape().last().copied().unwrap_or(0);
        assert_eq!(
            input_features, self.input_size,
            "RNNCell: expected input size {}, got {}",
            self.input_size, input_features
        );
        // Input projection: x W_ih^T + b_ih.
        let weight_ih = self.weight_ih.variable();
        let weight_ih_t = weight_ih.transpose(0, 1);
        let ih = input.matmul(&weight_ih_t);
        let bias_ih = self.bias_ih.variable();
        let ih = ih.add_var(&bias_ih);

        // Hidden projection: h W_hh^T + b_hh.
        let weight_hh = self.weight_hh.variable();
        let weight_hh_t = weight_hh.transpose(0, 1);
        let hh = hidden.matmul(&weight_hh_t);
        let bias_hh = self.bias_hh.variable();
        let hh = hh.add_var(&bias_hh);

        ih.add_var(&hh).tanh()
    }
}

impl Module for RNNCell {
    fn forward(&self, input: &Variable) -> Variable {
        let batch_size = input.shape()[0];
        let hidden = Variable::new(
            zeros(&[batch_size, self.hidden_size]),
            input.requires_grad(),
        );
        self.forward_step(input, &hidden)
    }

    fn parameters(&self) -> Vec<Parameter> {
        vec![
            self.weight_ih.clone(),
            self.weight_hh.clone(),
            self.bias_ih.clone(),
            self.bias_hh.clone(),
        ]
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut params = HashMap::new();
        params.insert("weight_ih".to_string(), self.weight_ih.clone());
        params.insert("weight_hh".to_string(), self.weight_hh.clone());
        params.insert("bias_ih".to_string(), self.bias_ih.clone());
        params.insert("bias_hh".to_string(), self.bias_hh.clone());
        params
    }

    fn name(&self) -> &'static str {
        "RNNCell"
    }
}

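/// A multi-layer vanilla RNN that unrolls [`RNNCell`]s over a sequence.
///
/// With `batch_first = true` (the default from [`RNN::new`]) the input is
/// `[batch, seq, features]` and the output `[batch, seq, hidden]`; otherwise
/// both are seq-first.
///
/// A minimal sketch (mirrors `test_rnn` below):
///
/// ```ignore
/// let rnn = RNN::new(10, 20, 2);
/// let x = Variable::new(Tensor::from_vec(vec![1.0; 100], &[2, 5, 10]).unwrap(), false);
/// let y = rnn.forward(&x); // [2, 5, 20]
/// ```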
pub struct RNN {
    cells: Vec<RNNCell>,
    _input_size: usize,
    hidden_size: usize,
    num_layers: usize,
    batch_first: bool,
}

impl RNN {
    pub fn new(input_size: usize, hidden_size: usize, num_layers: usize) -> Self {
        Self::with_options(input_size, hidden_size, num_layers, true)
    }

    pub fn with_options(
        input_size: usize,
        hidden_size: usize,
        num_layers: usize,
        batch_first: bool,
    ) -> Self {
        let mut cells = Vec::with_capacity(num_layers);

        cells.push(RNNCell::new(input_size, hidden_size));

        for _ in 1..num_layers {
            cells.push(RNNCell::new(hidden_size, hidden_size));
        }

        Self {
            cells,
            _input_size: input_size,
            hidden_size,
            num_layers,
            batch_first,
        }
    }
}

impl Module for RNN {
    fn forward(&self, input: &Variable) -> Variable {
        let shape = input.shape();
        let (batch_size, seq_len, input_features) = if self.batch_first {
            (shape[0], shape[1], shape[2])
        } else {
            (shape[1], shape[0], shape[2])
        };

        let mut hiddens: Vec<Variable> = (0..self.num_layers)
            .map(|_| {
                Variable::new(
                    zeros(&[batch_size, self.hidden_size]),
                    input.requires_grad(),
                )
            })
            .collect();

        // Precompute the layer-0 input projection for the whole sequence with
        // one matmul, then slice a single timestep per iteration below.
        let cell0 = &self.cells[0];
        let input_2d = input.reshape(&[batch_size * seq_len, input_features]);
        let w_ih_t = cell0.weight_ih.variable().transpose(0, 1);
        let ih_all = input_2d.matmul(&w_ih_t).add_var(&cell0.bias_ih.variable());
        // View the projection with the same layout as the input: a row-major
        // flatten of a seq-first tensor must come back as [seq, batch, hidden].
        let (ih_all_3d, time_axis) = if self.batch_first {
            (ih_all.reshape(&[batch_size, seq_len, self.hidden_size]), 1)
        } else {
            (ih_all.reshape(&[seq_len, batch_size, self.hidden_size]), 0)
        };

        let w_hh_t_0 = cell0.weight_hh.variable().transpose(0, 1);
        let bias_hh_0 = cell0.bias_hh.variable();

        let mut outputs = Vec::with_capacity(seq_len);

        for t in 0..seq_len {
            let ih_t = ih_all_3d.select(time_axis, t);
            let hh = hiddens[0].matmul(&w_hh_t_0).add_var(&bias_hh_0);
            hiddens[0] = ih_t.add_var(&hh).tanh();

            for l in 1..self.num_layers {
                let layer_input = hiddens[l - 1].clone();
                hiddens[l] = self.cells[l].forward_step(&layer_input, &hiddens[l]);
            }

            outputs.push(hiddens[self.num_layers - 1].clone());
        }

        let time_dim = usize::from(self.batch_first);
        let unsqueezed: Vec<Variable> = outputs.iter().map(|o| o.unsqueeze(time_dim)).collect();
        let refs: Vec<&Variable> = unsqueezed.iter().collect();
        Variable::cat(&refs, time_dim)
    }

    fn parameters(&self) -> Vec<Parameter> {
        self.cells.iter().flat_map(|c| c.parameters()).collect()
    }

    fn name(&self) -> &'static str {
        "RNN"
    }
}

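/// A single-step LSTM cell. The four gates are packed along the feature
/// dimension of one `4 * hidden` projection, in `i, f, g, o` order:
///
/// ```text
/// i = sigmoid(.)  f = sigmoid(.)  g = tanh(.)  o = sigmoid(.)
/// c' = f * c + i * g
/// h' = o * tanh(c')
/// ```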
pub struct LSTMCell {
    pub weight_ih: Parameter,
    pub weight_hh: Parameter,
    pub bias_ih: Parameter,
    pub bias_hh: Parameter,
    input_size: usize,
    hidden_size: usize,
}

impl LSTMCell {
    pub fn new(input_size: usize, hidden_size: usize) -> Self {
        Self {
            weight_ih: Parameter::named(
                "weight_ih",
                xavier_uniform(input_size, 4 * hidden_size),
                true,
            ),
            weight_hh: Parameter::named(
                "weight_hh",
                xavier_uniform(hidden_size, 4 * hidden_size),
                true,
            ),
            bias_ih: Parameter::named("bias_ih", zeros(&[4 * hidden_size]), true),
            bias_hh: Parameter::named("bias_hh", zeros(&[4 * hidden_size]), true),
            input_size,
            hidden_size,
        }
    }

    pub fn input_size(&self) -> usize {
        self.input_size
    }

    pub fn hidden_size(&self) -> usize {
        self.hidden_size
    }

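    /// Runs one step from `(h, c)` and returns `(h', c')`.
    ///
    /// Panics if the trailing dimension of `input` does not match `input_size`.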
    pub fn forward_step(
        &self,
        input: &Variable,
        hx: &(Variable, Variable),
    ) -> (Variable, Variable) {
        let input_features = input.data().shape().last().copied().unwrap_or(0);
        assert_eq!(
            input_features, self.input_size,
            "LSTMCell: expected input size {}, got {}",
            self.input_size, input_features
        );

        let (h, c) = hx;

        let weight_ih = self.weight_ih.variable();
        let weight_ih_t = weight_ih.transpose(0, 1);
        let ih = input.matmul(&weight_ih_t);
        let bias_ih = self.bias_ih.variable();
        let ih = ih.add_var(&bias_ih);

        let weight_hh = self.weight_hh.variable();
        let weight_hh_t = weight_hh.transpose(0, 1);
        let hh = h.matmul(&weight_hh_t);
        let bias_hh = self.bias_hh.variable();
        let hh = hh.add_var(&bias_hh);

        // All four gate pre-activations, packed along dim 1 in i, f, g, o order.
        let gates = ih.add_var(&hh);
        let hs = self.hidden_size;

        let i = gates.narrow(1, 0, hs).sigmoid(); // input gate
        let f = gates.narrow(1, hs, hs).sigmoid(); // forget gate
        let g = gates.narrow(1, 2 * hs, hs).tanh(); // candidate cell
        let o = gates.narrow(1, 3 * hs, hs).sigmoid(); // output gate

        // c' = f * c + i * g
        let c_new = f.mul_var(c).add_var(&i.mul_var(&g));
        // h' = o * tanh(c')
        let h_new = o.mul_var(&c_new.tanh());

        (h_new, c_new)
    }
}

impl Module for LSTMCell {
    fn forward(&self, input: &Variable) -> Variable {
        let batch_size = input.shape()[0];
        let h = Variable::new(
            zeros(&[batch_size, self.hidden_size]),
            input.requires_grad(),
        );
        let c = Variable::new(
            zeros(&[batch_size, self.hidden_size]),
            input.requires_grad(),
        );
        let (h_new, _) = self.forward_step(input, &(h, c));
        h_new
    }

    fn parameters(&self) -> Vec<Parameter> {
        vec![
            self.weight_ih.clone(),
            self.weight_hh.clone(),
            self.bias_ih.clone(),
            self.bias_hh.clone(),
        ]
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut params = HashMap::new();
        params.insert("weight_ih".to_string(), self.weight_ih.clone());
        params.insert("weight_hh".to_string(), self.weight_hh.clone());
        params.insert("bias_ih".to_string(), self.bias_ih.clone());
        params.insert("bias_hh".to_string(), self.bias_hh.clone());
        params
    }

    fn name(&self) -> &'static str {
        "LSTMCell"
    }
}

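/// A multi-layer LSTM over a sequence of inputs.
///
/// Layer 0's input projection is batched over the whole sequence with a single
/// matmul; with the `cuda` feature the per-step gate math additionally goes
/// through a fused kernel when one is available.
///
/// A minimal sketch (mirrors `test_lstm` below):
///
/// ```ignore
/// let lstm = LSTM::new(10, 20, 1);
/// let x = Variable::new(Tensor::from_vec(vec![1.0; 100], &[2, 5, 10]).unwrap(), false);
/// let y = lstm.forward(&x); // [2, 5, 20]
/// ```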
pub struct LSTM {
    cells: Vec<LSTMCell>,
    input_size: usize,
    hidden_size: usize,
    num_layers: usize,
    batch_first: bool,
}

impl LSTM {
    pub fn new(input_size: usize, hidden_size: usize, num_layers: usize) -> Self {
        Self::with_options(input_size, hidden_size, num_layers, true)
    }

    pub fn with_options(
        input_size: usize,
        hidden_size: usize,
        num_layers: usize,
        batch_first: bool,
    ) -> Self {
        let mut cells = Vec::with_capacity(num_layers);
        cells.push(LSTMCell::new(input_size, hidden_size));
        for _ in 1..num_layers {
            cells.push(LSTMCell::new(hidden_size, hidden_size));
        }

        Self {
            cells,
            input_size,
            hidden_size,
            num_layers,
            batch_first,
        }
    }

    pub fn input_size(&self) -> usize {
        self.input_size
    }

    pub fn hidden_size(&self) -> usize {
        self.hidden_size
    }

    pub fn num_layers(&self) -> usize {
        self.num_layers
    }
}

impl Module for LSTM {
    fn forward(&self, input: &Variable) -> Variable {
        let shape = input.shape();
        let (batch_size, seq_len, input_features) = if self.batch_first {
            (shape[0], shape[1], shape[2])
        } else {
            (shape[1], shape[0], shape[2])
        };

        let lstm_input_device = input.data().device();
        #[cfg(feature = "cuda")]
        let lstm_on_gpu = lstm_input_device.is_gpu();
        #[cfg(not(feature = "cuda"))]
        let lstm_on_gpu = false;

        let mut states: Vec<(Variable, Variable)> = (0..self.num_layers)
            .map(|_| {
                let make_h = || {
                    let h_cpu = zeros(&[batch_size, self.hidden_size]);
                    let h_tensor = if lstm_on_gpu {
                        h_cpu
                            .to_device(lstm_input_device)
                            .expect("LSTM: failed to move hidden state to GPU")
                    } else {
                        h_cpu
                    };
                    Variable::new(h_tensor, input.requires_grad())
                };
                (make_h(), make_h())
            })
            .collect();

        // Precompute the layer-0 input projection for the whole sequence with
        // one matmul, then slice a single timestep per iteration below.
        let cell0 = &self.cells[0];
        let input_2d = input.reshape(&[batch_size * seq_len, input_features]);
        let w_ih_t = cell0.weight_ih.variable().transpose(0, 1);
        let ih_all = input_2d.matmul(&w_ih_t).add_var(&cell0.bias_ih.variable());
        // View the projection with the same layout as the input: a row-major
        // flatten of a seq-first tensor must come back as [seq, batch, 4 * hidden].
        let (ih_all_3d, time_axis) = if self.batch_first {
            (ih_all.reshape(&[batch_size, seq_len, 4 * self.hidden_size]), 1)
        } else {
            (ih_all.reshape(&[seq_len, batch_size, 4 * self.hidden_size]), 0)
        };

        let w_hh_t_0 = cell0.weight_hh.variable().transpose(0, 1);
        let bias_hh_0 = cell0.bias_hh.variable();

        let mut outputs = Vec::with_capacity(seq_len);

        // `lstm_on_gpu` was computed above; reuse it rather than re-querying the device.
        let on_gpu = lstm_on_gpu;

        for t in 0..seq_len {
            let ih_t = ih_all_3d.select(time_axis, t);
            let (h, c) = &states[0];

            let hh = h.matmul(&w_hh_t_0).add_var(&bias_hh_0);

            let gates = ih_t.add_var(&hh);

            if on_gpu {
                #[cfg(feature = "cuda")]
                {
                    let hs = self.hidden_size;
                    let gates_data = gates.data();
                    let c_data = c.data();

                    // Fused CUDA kernel covering all four gates and the cell
                    // update. Note: if the kernel declines these shapes the
                    // state is left unchanged for this timestep; the unfused
                    // gate math lives only in the CPU branch below.
                    if let Some((h_tensor, c_tensor)) = gates_data.lstm_gates_fused(&c_data, hs) {
                        let saved_gates = gates_data.clone();
                        let saved_c_prev = c_data.clone();
                        let saved_c_new = c_tensor.clone();

                        let backward_fn = axonml_autograd::LstmGatesBackward::new(
                            gates.grad_fn().cloned(),
                            c.grad_fn().cloned(),
                            saved_gates,
                            saved_c_prev,
                            saved_c_new,
                            hs,
                        );
                        let grad_fn = axonml_autograd::GradFn::new(backward_fn);

                        let fused_requires_grad = gates.requires_grad() || c.requires_grad();
                        let h_new = Variable::from_operation(
                            h_tensor,
                            grad_fn.clone(),
                            fused_requires_grad,
                        );
                        let c_new =
                            Variable::from_operation(c_tensor, grad_fn, fused_requires_grad);
                        states[0] = (h_new, c_new);
                    }
                }
            } else {
                let hs = self.hidden_size;
                let i_gate = gates.narrow(1, 0, hs).sigmoid(); // input gate
                let f_gate = gates.narrow(1, hs, hs).sigmoid(); // forget gate
                let g_gate = gates.narrow(1, 2 * hs, hs).tanh(); // candidate cell
                let o_gate = gates.narrow(1, 3 * hs, hs).sigmoid(); // output gate
                let c_new = f_gate.mul_var(c).add_var(&i_gate.mul_var(&g_gate));
                let h_new = o_gate.mul_var(&c_new.tanh());
                states[0] = (h_new, c_new);
            }

            for l in 1..self.num_layers {
                let layer_input = states[l - 1].0.clone();
                states[l] = self.cells[l].forward_step(&layer_input, &states[l]);
            }

            outputs.push(states[self.num_layers - 1].0.clone());
        }

        let time_dim = usize::from(self.batch_first);
        let unsqueezed: Vec<Variable> = outputs.iter().map(|o| o.unsqueeze(time_dim)).collect();
        let refs: Vec<&Variable> = unsqueezed.iter().collect();
        Variable::cat(&refs, time_dim)
    }

    fn parameters(&self) -> Vec<Parameter> {
        self.cells.iter().flat_map(|c| c.parameters()).collect()
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut params = HashMap::new();
        if self.cells.len() == 1 {
            for (n, p) in self.cells[0].named_parameters() {
                params.insert(n, p);
            }
        } else {
            for (i, cell) in self.cells.iter().enumerate() {
                for (n, p) in cell.named_parameters() {
                    params.insert(format!("cells.{i}.{n}"), p);
                }
            }
        }
        params
    }

    fn name(&self) -> &'static str {
        "LSTM"
    }
}

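/// A single-step GRU cell. The three gates are packed along the feature
/// dimension of one `3 * hidden` projection, in `r, z, n` order:
///
/// ```text
/// r  = sigmoid(ih_r + hh_r)
/// z  = sigmoid(ih_z + hh_z)
/// n  = tanh(ih_n + r * hh_n)
/// h' = n + z * (h - n)        // == (1 - z) * n + z * h
/// ```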
pub struct GRUCell {
    pub weight_ih: Parameter,
    pub weight_hh: Parameter,
    pub bias_ih: Parameter,
    pub bias_hh: Parameter,
    input_size: usize,
    hidden_size: usize,
}

impl GRUCell {
    pub fn new(input_size: usize, hidden_size: usize) -> Self {
        Self {
            weight_ih: Parameter::named(
                "weight_ih",
                xavier_uniform(input_size, 3 * hidden_size),
                true,
            ),
            weight_hh: Parameter::named(
                "weight_hh",
                xavier_uniform(hidden_size, 3 * hidden_size),
                true,
            ),
            bias_ih: Parameter::named("bias_ih", zeros(&[3 * hidden_size]), true),
            bias_hh: Parameter::named("bias_hh", zeros(&[3 * hidden_size]), true),
            input_size,
            hidden_size,
        }
    }

    pub fn input_size(&self) -> usize {
        self.input_size
    }

    pub fn hidden_size(&self) -> usize {
        self.hidden_size
    }

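    /// Runs one recurrence step and returns the new hidden state.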
    pub fn forward_step(&self, input: &Variable, hidden: &Variable) -> Variable {
        let hidden_size = self.hidden_size;

        let weight_ih = self.weight_ih.variable();
        let weight_hh = self.weight_hh.variable();
        let bias_ih = self.bias_ih.variable();
        let bias_hh = self.bias_hh.variable();

        // Input and hidden projections, each [batch, 3 * hidden] in r, z, n order.
        let weight_ih_t = weight_ih.transpose(0, 1);
        let ih = input.matmul(&weight_ih_t).add_var(&bias_ih);

        let weight_hh_t = weight_hh.transpose(0, 1);
        let hh = hidden.matmul(&weight_hh_t).add_var(&bias_hh);

        let ih_r = ih.narrow(1, 0, hidden_size);
        let ih_z = ih.narrow(1, hidden_size, hidden_size);
        let ih_n = ih.narrow(1, 2 * hidden_size, hidden_size);

        let hh_r = hh.narrow(1, 0, hidden_size);
        let hh_z = hh.narrow(1, hidden_size, hidden_size);
        let hh_n = hh.narrow(1, 2 * hidden_size, hidden_size);

        // Reset gate.
        let r = ih_r.add_var(&hh_r).sigmoid();
        // Update gate.
        let z = ih_z.add_var(&hh_z).sigmoid();
        // Candidate state; the reset gate scales the hidden contribution.
        let n = ih_n.add_var(&r.mul_var(&hh_n)).tanh();

        // h' = n + z * (h - n), i.e. (1 - z) * n + z * h.
        let h_minus_n = hidden.sub_var(&n);
        n.add_var(&z.mul_var(&h_minus_n))
    }
}

impl Module for GRUCell {
    fn forward(&self, input: &Variable) -> Variable {
        let batch_size = input.shape()[0];

        let hidden = Variable::new(
            zeros(&[batch_size, self.hidden_size]),
            input.requires_grad(),
        );

        self.forward_step(input, &hidden)
    }

    fn parameters(&self) -> Vec<Parameter> {
        vec![
            self.weight_ih.clone(),
            self.weight_hh.clone(),
            self.bias_ih.clone(),
            self.bias_hh.clone(),
        ]
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut params = HashMap::new();
        params.insert("weight_ih".to_string(), self.weight_ih.clone());
        params.insert("weight_hh".to_string(), self.weight_hh.clone());
        params.insert("bias_ih".to_string(), self.bias_ih.clone());
        params.insert("bias_hh".to_string(), self.bias_hh.clone());
        params
    }

    fn name(&self) -> &'static str {
        "GRUCell"
    }
}

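/// A multi-layer GRU over a sequence of inputs.
///
/// Only the batch-first layout is constructible (see [`GRU::new`]), so inputs
/// are `[batch, seq, features]` and outputs `[batch, seq, hidden]`.
///
/// A minimal sketch (mirrors `test_gru_multi_layer` below):
///
/// ```ignore
/// let gru = GRU::new(8, 16, 2);
/// let x = Variable::new(Tensor::from_vec(vec![0.5; 80], &[2, 5, 8]).unwrap(), false);
/// let y = gru.forward(&x); // [2, 5, 16]
/// ```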
pub struct GRU {
    cells: Vec<GRUCell>,
    hidden_size: usize,
    num_layers: usize,
    batch_first: bool,
}

impl GRU {
    pub fn new(input_size: usize, hidden_size: usize, num_layers: usize) -> Self {
        let mut cells = Vec::with_capacity(num_layers);
        cells.push(GRUCell::new(input_size, hidden_size));
        for _ in 1..num_layers {
            cells.push(GRUCell::new(hidden_size, hidden_size));
        }
        Self {
            cells,
            hidden_size,
            num_layers,
            // Batch-first is the only constructible layout; `stack_outputs`
            // below also assumes it.
            batch_first: true,
        }
    }

    pub fn hidden_size(&self) -> usize {
        self.hidden_size
    }

    pub fn num_layers(&self) -> usize {
        self.num_layers
    }
}

impl Module for GRU {
    fn forward(&self, input: &Variable) -> Variable {
        let shape = input.shape();
        let (batch_size, seq_len, input_features) = if self.batch_first {
            (shape[0], shape[1], shape[2])
        } else {
            (shape[1], shape[0], shape[2])
        };

        #[cfg(feature = "cuda")]
        let on_gpu = input.data().device().is_gpu();
        #[cfg(not(feature = "cuda"))]
        let on_gpu = false;

        let input_device = input.data().device();

        let mut hidden_states: Vec<Variable> = (0..self.num_layers)
            .map(|_| {
                let h_cpu = zeros(&[batch_size, self.hidden_size]);
                let h_tensor = if on_gpu {
                    h_cpu
                        .to_device(input_device)
                        .expect("GRU: failed to move hidden state to GPU")
                } else {
                    h_cpu
                };
                Variable::new(h_tensor, input.requires_grad())
            })
            .collect();

        // Precompute the layer-0 input projection for the whole sequence
        // (batch-first, the only layout this type constructs).
        let cell0 = &self.cells[0];
        let input_2d = input.reshape(&[batch_size * seq_len, input_features]);
        let w_ih_t = cell0.weight_ih.variable().transpose(0, 1);
        let ih_all = input_2d.matmul(&w_ih_t).add_var(&cell0.bias_ih.variable());
        let ih_all_3d = ih_all.reshape(&[batch_size, seq_len, 3 * self.hidden_size]);

        let w_hh_t_0 = cell0.weight_hh.variable().transpose(0, 1);
        let bias_hh_0 = cell0.bias_hh.variable();

        let mut output_vars: Vec<Variable> = Vec::with_capacity(seq_len);

        for t in 0..seq_len {
            let ih_t = ih_all_3d.select(1, t);
            let hidden = &hidden_states[0];
            let hs = self.hidden_size;

            let hh = hidden.matmul(&w_hh_t_0).add_var(&bias_hh_0);

            if on_gpu {
                #[cfg(feature = "cuda")]
                {
                    let ih_data = ih_t.data();
                    let hh_data = hh.data();
                    let h_data = hidden.data();

                    // Fused CUDA kernel covering all three gates and the hidden
                    // update. Note: if the kernel declines these shapes the
                    // state is left unchanged for this timestep; the unfused
                    // gate math lives only in the CPU branch below.
                    if let Some(h_tensor) = ih_data.gru_gates_fused(&hh_data, &h_data, hs) {
                        let saved_ih = ih_data.clone();
                        let saved_hh = hh_data.clone();
                        let saved_h_prev = h_data.clone();

                        let backward_fn = axonml_autograd::GruGatesBackward::new(
                            ih_t.grad_fn().cloned(),
                            hh.grad_fn().cloned(),
                            hidden.grad_fn().cloned(),
                            saved_ih,
                            saved_hh,
                            saved_h_prev,
                            hs,
                        );
                        let grad_fn = axonml_autograd::GradFn::new(backward_fn);

                        let fused_requires_grad =
                            ih_t.requires_grad() || hh.requires_grad() || hidden.requires_grad();
                        let h_new =
                            Variable::from_operation(h_tensor, grad_fn, fused_requires_grad);
                        hidden_states[0] = h_new;
                    }
                }
            } else {
                let ih_r = ih_t.narrow(1, 0, hs);
                let ih_z = ih_t.narrow(1, hs, hs);
                let ih_n = ih_t.narrow(1, 2 * hs, hs);
                let hh_r = hh.narrow(1, 0, hs);
                let hh_z = hh.narrow(1, hs, hs);
                let hh_n = hh.narrow(1, 2 * hs, hs);

                let r = ih_r.add_var(&hh_r).sigmoid(); // reset gate
                let z = ih_z.add_var(&hh_z).sigmoid(); // update gate
                let n = ih_n.add_var(&r.mul_var(&hh_n)).tanh(); // candidate
                let h_minus_n = hidden.sub_var(&n);
                let h_new = n.add_var(&z.mul_var(&h_minus_n)); // (1 - z) * n + z * h
                hidden_states[0] = h_new;
            }

            let mut layer_output = hidden_states[0].clone();
            for l in 1..self.num_layers {
                let new_hidden = self.cells[l].forward_step(&layer_output, &hidden_states[l]);
                hidden_states[l] = new_hidden.clone();
                layer_output = new_hidden;
            }

            output_vars.push(layer_output);
        }

        self.stack_outputs(&output_vars, batch_size, seq_len)
    }

    fn parameters(&self) -> Vec<Parameter> {
        self.cells.iter().flat_map(|c| c.parameters()).collect()
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut params = HashMap::new();
        if self.cells.len() == 1 {
            for (n, p) in self.cells[0].named_parameters() {
                params.insert(n, p);
            }
        } else {
            for (i, cell) in self.cells.iter().enumerate() {
                for (n, p) in cell.named_parameters() {
                    params.insert(format!("cells.{i}.{n}"), p);
                }
            }
        }
        params
    }

    fn name(&self) -> &'static str {
        "GRU"
    }
}

impl GRU {
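    /// Runs the GRU and returns the top-layer output averaged over time,
    /// shape `[batch, hidden]`: a simple mean-pooling head over the sequence.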
    pub fn forward_mean(&self, input: &Variable) -> Variable {
        let shape = input.shape();
        let (batch_size, seq_len, input_features) = if self.batch_first {
            (shape[0], shape[1], shape[2])
        } else {
            (shape[1], shape[0], shape[2])
        };

        let mut hidden_states: Vec<Variable> = (0..self.num_layers)
            .map(|_| {
                Variable::new(
                    zeros(&[batch_size, self.hidden_size]),
                    input.requires_grad(),
                )
            })
            .collect();

        let cell0 = &self.cells[0];
        let input_2d = input.reshape(&[batch_size * seq_len, input_features]);
        let w_ih_t = cell0.weight_ih.variable().transpose(0, 1);
        let ih_all = input_2d.matmul(&w_ih_t).add_var(&cell0.bias_ih.variable());
        let ih_all_3d = ih_all.reshape(&[batch_size, seq_len, 3 * self.hidden_size]);

        let w_hh_t_0 = cell0.weight_hh.variable().transpose(0, 1);
        let bias_hh_0 = cell0.bias_hh.variable();

        let mut output_sum: Option<Variable> = None;
        let hs = self.hidden_size;

        for t in 0..seq_len {
            let ih_t = ih_all_3d.select(1, t);
            let hidden = &hidden_states[0];
            let hh = hidden.matmul(&w_hh_t_0).add_var(&bias_hh_0);

            let ih_r = ih_t.narrow(1, 0, hs);
            let ih_z = ih_t.narrow(1, hs, hs);
            let ih_n = ih_t.narrow(1, 2 * hs, hs);
            let hh_r = hh.narrow(1, 0, hs);
            let hh_z = hh.narrow(1, hs, hs);
            let hh_n = hh.narrow(1, 2 * hs, hs);

            let r = ih_r.add_var(&hh_r).sigmoid();
            let z = ih_z.add_var(&hh_z).sigmoid();
            let n = ih_n.add_var(&r.mul_var(&hh_n)).tanh();
            let h_minus_n = hidden.sub_var(&n);
            let h_new = n.add_var(&z.mul_var(&h_minus_n));
            hidden_states[0] = h_new.clone();

            let mut layer_output = h_new;
            for l in 1..self.num_layers {
                let new_hidden = self.cells[l].forward_step(&layer_output, &hidden_states[l]);
                hidden_states[l] = new_hidden.clone();
                layer_output = new_hidden;
            }

            output_sum = Some(match output_sum {
                None => layer_output,
                Some(acc) => acc.add_var(&layer_output),
            });
        }

        match output_sum {
            Some(sum) => sum.mul_scalar(1.0 / seq_len as f32),
            None => Variable::new(zeros(&[batch_size, self.hidden_size]), false),
        }
    }

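    /// Runs the GRU and returns only the top layer's hidden state after the
    /// final timestep, shape `[batch, hidden]`.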
    pub fn forward_last(&self, input: &Variable) -> Variable {
        let shape = input.shape();
        let (batch_size, seq_len, input_features) = if self.batch_first {
            (shape[0], shape[1], shape[2])
        } else {
            (shape[1], shape[0], shape[2])
        };

        let mut hidden_states: Vec<Variable> = (0..self.num_layers)
            .map(|_| {
                Variable::new(
                    zeros(&[batch_size, self.hidden_size]),
                    input.requires_grad(),
                )
            })
            .collect();

        let cell0 = &self.cells[0];
        let input_2d = input.reshape(&[batch_size * seq_len, input_features]);
        let w_ih_t = cell0.weight_ih.variable().transpose(0, 1);
        let ih_all = input_2d.matmul(&w_ih_t).add_var(&cell0.bias_ih.variable());
        let ih_all_3d = ih_all.reshape(&[batch_size, seq_len, 3 * self.hidden_size]);

        let w_hh_t_0 = cell0.weight_hh.variable().transpose(0, 1);
        let bias_hh_0 = cell0.bias_hh.variable();
        let hs = self.hidden_size;

        for t in 0..seq_len {
            let ih_t = ih_all_3d.select(1, t);
            let hidden = &hidden_states[0];
            let hh = hidden.matmul(&w_hh_t_0).add_var(&bias_hh_0);

            let ih_r = ih_t.narrow(1, 0, hs);
            let ih_z = ih_t.narrow(1, hs, hs);
            let ih_n = ih_t.narrow(1, 2 * hs, hs);
            let hh_r = hh.narrow(1, 0, hs);
            let hh_z = hh.narrow(1, hs, hs);
            let hh_n = hh.narrow(1, 2 * hs, hs);

            let r = ih_r.add_var(&hh_r).sigmoid();
            let z = ih_z.add_var(&hh_z).sigmoid();
            let n = ih_n.add_var(&r.mul_var(&hh_n)).tanh();
            let h_minus_n = hidden.sub_var(&n);
            let h_new = n.add_var(&z.mul_var(&h_minus_n));
            hidden_states[0] = h_new.clone();

            let mut layer_input = h_new;

            for (layer_idx, cell) in self.cells.iter().enumerate().skip(1) {
                let new_hidden = cell.forward_step(&layer_input, &hidden_states[layer_idx]);
                hidden_states[layer_idx] = new_hidden.clone();
                layer_input = new_hidden;
            }
        }

        hidden_states
            .pop()
            .unwrap_or_else(|| Variable::new(zeros(&[batch_size, self.hidden_size]), false))
    }

    /// Stacks per-timestep `[batch, hidden]` outputs into `[batch, seq, hidden]`.
    /// Always batch-first, which matches the only layout `GRU::new` constructs.
    fn stack_outputs(&self, outputs: &[Variable], batch_size: usize, _seq_len: usize) -> Variable {
        if outputs.is_empty() {
            return Variable::new(zeros(&[batch_size, 0, self.hidden_size]), false);
        }

        let unsqueezed: Vec<Variable> = outputs.iter().map(|o| o.unsqueeze(1)).collect();
        let refs: Vec<&Variable> = unsqueezed.iter().collect();
        Variable::cat(&refs, 1)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use axonml_tensor::Tensor;

    #[test]
    fn test_rnn_cell() {
        let cell = RNNCell::new(10, 20);
        let input = Variable::new(Tensor::from_vec(vec![1.0; 20], &[2, 10]).unwrap(), false);
        let hidden = Variable::new(Tensor::from_vec(vec![0.0; 40], &[2, 20]).unwrap(), false);
        let output = cell.forward_step(&input, &hidden);
        assert_eq!(output.shape(), vec![2, 20]);
    }

    #[test]
    fn test_rnn() {
        let rnn = RNN::new(10, 20, 2);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 100], &[2, 5, 10]).unwrap(),
            false,
        );
        let output = rnn.forward(&input);
        assert_eq!(output.shape(), vec![2, 5, 20]);
    }

    #[test]
    fn test_lstm() {
        let lstm = LSTM::new(10, 20, 1);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 100], &[2, 5, 10]).unwrap(),
            false,
        );
        let output = lstm.forward(&input);
        assert_eq!(output.shape(), vec![2, 5, 20]);
    }

    #[test]
    fn test_gru_gradients_reach_parameters() {
        let gru = GRU::new(4, 8, 1);
        let input = Variable::new(
            Tensor::from_vec(vec![0.5f32; 2 * 3 * 4], &[2, 3, 4]).unwrap(),
            true,
        );
        let output = gru.forward(&input);
        println!(
            "Output shape: {:?}, requires_grad: {}",
            output.shape(),
            output.requires_grad()
        );
        let loss = output.sum();
        println!(
            "Loss: {:?}, requires_grad: {}",
            loss.data().to_vec(),
            loss.requires_grad()
        );
        loss.backward();

        println!(
            "Input grad: {:?}",
            input
                .grad()
                .map(|g| g.to_vec().iter().map(|x| x.abs()).sum::<f32>())
        );

        let params = gru.parameters();
        println!("Number of parameters: {}", params.len());
        let mut has_grad = false;
        for (i, p) in params.iter().enumerate() {
            let grad = p.grad();
            match grad {
                Some(g) => {
                    let gv = g.to_vec();
                    let sum_abs: f32 = gv.iter().map(|x| x.abs()).sum();
                    println!(
                        "Param {} shape {:?} requires_grad={}: grad sum_abs={:.6}",
                        i,
                        p.shape(),
                        p.requires_grad(),
                        sum_abs
                    );
                    if sum_abs > 0.0 {
                        has_grad = true;
                    }
                }
                None => {
                    println!(
                        "Param {} shape {:?} requires_grad={}: NO GRADIENT",
                        i,
                        p.shape(),
                        p.requires_grad()
                    );
                }
            }
        }
        assert!(
            has_grad,
            "At least one GRU parameter should have non-zero gradients"
        );
    }

    #[test]
    fn test_lstm_cell_forward_step() {
        let cell = LSTMCell::new(8, 16);
        let input = Variable::new(Tensor::from_vec(vec![1.0; 2 * 8], &[2, 8]).unwrap(), false);
        let hidden = Variable::new(
            Tensor::from_vec(vec![0.0; 2 * 16], &[2, 16]).unwrap(),
            false,
        );
        let cell_state = Variable::new(
            Tensor::from_vec(vec![0.0; 2 * 16], &[2, 16]).unwrap(),
            false,
        );
        let hx = (hidden, cell_state);
        let (h, c) = cell.forward_step(&input, &hx);
        assert_eq!(h.shape(), vec![2, 16]);
        assert_eq!(c.shape(), vec![2, 16]);
    }

    #[test]
    fn test_lstm_multi_layer() {
        let lstm = LSTM::new(8, 16, 3);
        assert_eq!(lstm.num_layers(), 3);
        assert_eq!(lstm.hidden_size(), 16);

        let input = Variable::new(
            Tensor::from_vec(vec![0.5; 2 * 5 * 8], &[2, 5, 8]).unwrap(),
            false,
        );
        let output = lstm.forward(&input);
        assert_eq!(output.shape(), vec![2, 5, 16]);
    }

    #[test]
    fn test_lstm_last_timestep_is_finite() {
        let lstm = LSTM::new(8, 16, 1);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 2 * 10 * 8], &[2, 10, 8]).unwrap(),
            false,
        );
        let output = lstm.forward(&input);
        assert_eq!(output.shape(), vec![2, 10, 16]);

        let out_vec = output.data().to_vec();
        // Last timestep of the first batch element in the [2, 10, 16] output.
        let last_t0 = &out_vec[9 * 16..10 * 16];
        assert!(
            last_t0.iter().all(|v| v.is_finite()),
            "Last output should be finite"
        );
    }

    #[test]
    fn test_lstm_gradient_flow() {
        let lstm = LSTM::new(4, 8, 1);
        let input = Variable::new(
            Tensor::from_vec(vec![0.5; 3 * 4], &[1, 3, 4]).unwrap(),
            true,
        );
        let output = lstm.forward(&input);
        let loss = output.sum();
        loss.backward();

        let input_grad = input
            .grad()
            .expect("Input should have gradient through LSTM");
        assert_eq!(input_grad.shape(), &[1, 3, 4]);
        assert!(
            input_grad.to_vec().iter().any(|g| g.abs() > 1e-10),
            "LSTM should propagate gradients to input"
        );

        let params = lstm.parameters();
        let grads_exist = params.iter().any(|p| {
            p.grad()
                .is_some_and(|g| g.to_vec().iter().any(|v| v.abs() > 0.0))
        });
        assert!(grads_exist, "LSTM parameters should have gradients");
    }

    #[test]
    fn test_lstm_different_sequence_lengths() {
        let lstm = LSTM::new(4, 8, 1);

        let short = Variable::new(
            Tensor::from_vec(vec![1.0; 2 * 4], &[1, 2, 4]).unwrap(),
            false,
        );
        let out_short = lstm.forward(&short);
        assert_eq!(out_short.shape(), vec![1, 2, 8]);

        let long = Variable::new(
            Tensor::from_vec(vec![1.0; 20 * 4], &[1, 20, 4]).unwrap(),
            false,
        );
        let out_long = lstm.forward(&long);
        assert_eq!(out_long.shape(), vec![1, 20, 8]);
    }

    #[test]
    fn test_lstm_parameters_count() {
        let lstm = LSTM::new(10, 20, 1);
        // weight_ih: 4*20*10, weight_hh: 4*20*20, bias_ih: 4*20, bias_hh: 4*20.
        let n = lstm.parameters().iter().map(|p| p.numel()).sum::<usize>();
        assert_eq!(n, 800 + 1600 + 80 + 80, "LSTM(10, 20, 1) should have 2560 parameters");
    }

    #[test]
    fn test_gru_cell_forward_step() {
        let cell = GRUCell::new(8, 16);
        assert_eq!(cell.input_size(), 8);
        assert_eq!(cell.hidden_size(), 16);

        let input = Variable::new(Tensor::from_vec(vec![1.0; 2 * 8], &[2, 8]).unwrap(), false);
        let hidden = Variable::new(
            Tensor::from_vec(vec![0.0; 2 * 16], &[2, 16]).unwrap(),
            false,
        );
        let output = cell.forward_step(&input, &hidden);
        assert_eq!(output.shape(), vec![2, 16]);
    }

    #[test]
    fn test_gru_multi_layer() {
        let gru = GRU::new(8, 16, 2);
        assert_eq!(gru.num_layers(), 2);
        assert_eq!(gru.hidden_size(), 16);

        let input = Variable::new(
            Tensor::from_vec(vec![0.5; 2 * 5 * 8], &[2, 5, 8]).unwrap(),
            false,
        );
        let output = gru.forward(&input);
        assert_eq!(output.shape(), vec![2, 5, 16]);
    }

    #[test]
    fn test_gru_forward_mean() {
        let gru = GRU::new(4, 8, 1);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 2 * 5 * 4], &[2, 5, 4]).unwrap(),
            false,
        );
        let mean_out = gru.forward_mean(&input);
        assert_eq!(mean_out.shape(), vec![2, 8]);
    }

    #[test]
    fn test_gru_forward_last() {
        let gru = GRU::new(4, 8, 1);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 2 * 5 * 4], &[2, 5, 4]).unwrap(),
            false,
        );
        let last_out = gru.forward_last(&input);
        assert_eq!(last_out.shape(), vec![2, 8]);
    }

    #[test]
    fn test_gru_gradient_flow_to_input() {
        let gru = GRU::new(4, 8, 1);
        let input = Variable::new(
            Tensor::from_vec(vec![0.5; 3 * 4], &[1, 3, 4]).unwrap(),
            true,
        );
        let output = gru.forward(&input);
        output.sum().backward();

        let grad = input
            .grad()
            .expect("Input should have gradient through GRU");
        assert_eq!(grad.shape(), &[1, 3, 4]);
        assert!(
            grad.to_vec().iter().any(|g| g.abs() > 1e-10),
            "GRU should propagate gradients"
        );
    }

    #[test]
    fn test_gru_hidden_state_evolves() {
        let gru = GRU::new(4, 8, 1);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 5 * 4], &[1, 5, 4]).unwrap(),
            false,
        );
        let output = gru.forward(&input);
        let out_vec = output.data().to_vec();

        // First and fifth timesteps of the single batch element.
        let t0 = &out_vec[0..8];
        let t4 = &out_vec[4 * 8..5 * 8];
        let diff: f32 = t0.iter().zip(t4.iter()).map(|(a, b)| (a - b).abs()).sum();
        assert!(
            diff > 1e-6,
            "GRU hidden state should evolve over time, diff={}",
            diff
        );
    }

    #[test]
    fn test_rnn_cell_gradient_flow() {
        let cell = RNNCell::new(4, 8);
        let input = Variable::new(Tensor::from_vec(vec![1.0; 4], &[1, 4]).unwrap(), true);
        let hidden = Variable::new(Tensor::from_vec(vec![0.0; 8], &[1, 8]).unwrap(), false);
        let out = cell.forward_step(&input, &hidden);
        out.sum().backward();

        let grad = input.grad().expect("RNNCell should propagate gradients");
        assert_eq!(grad.shape(), &[1, 4]);
    }

    #[test]
    fn test_rnn_multi_layer() {
        let rnn = RNN::with_options(8, 16, 3, true);
        let input = Variable::new(
            Tensor::from_vec(vec![0.5; 2 * 5 * 8], &[2, 5, 8]).unwrap(),
            false,
        );
        let output = rnn.forward(&input);
        assert_eq!(output.shape(), vec![2, 5, 16]);
    }

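    // A hedged check of the seq-first layout: with `batch_first = false` the
    // input is [seq, batch, features] and the output [seq, batch, hidden].
    // Relies on `reshape` flattening row-major, which is what `forward`
    // assumes when it views the precomputed projection seq-first.
    #[test]
    fn test_rnn_seq_first_shapes() {
        let rnn = RNN::with_options(10, 20, 1, false);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 100], &[5, 2, 10]).unwrap(),
            false,
        );
        let output = rnn.forward(&input);
        assert_eq!(output.shape(), vec![5, 2, 20]);
    }
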
    #[test]
    fn test_lstm_outputs_are_bounded() {
        let lstm = LSTM::new(4, 8, 1);
        let input = Variable::new(
            Tensor::from_vec(vec![100.0; 10 * 4], &[1, 10, 4]).unwrap(),
            false,
        );
        let output = lstm.forward(&input);
        let out_vec = output.data().to_vec();

        for v in &out_vec {
            assert!(v.is_finite(), "LSTM output should be finite, got {}", v);
            assert!(
                v.abs() <= 1.0 + 1e-5,
                "LSTM output should be bounded by tanh: got {}",
                v
            );
        }
    }

    #[test]
    fn test_gru_outputs_finite_with_large_input() {
        let gru = GRU::new(4, 8, 1);
        let input = Variable::new(
            Tensor::from_vec(vec![50.0; 5 * 4], &[1, 5, 4]).unwrap(),
            false,
        );
        let output = gru.forward(&input);
        assert!(
            output.data().to_vec().iter().all(|v| v.is_finite()),
            "GRU should produce finite outputs for large inputs"
        );
    }
}