//! Recurrent layers: Elman RNN, LSTM, and GRU, plus their single-step cells.

use std::collections::HashMap;

use axonml_autograd::Variable;

use crate::init::{xavier_uniform, zeros};
use crate::module::Module;
use crate::parameter::Parameter;

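/// A single Elman RNN cell.
///
/// One step computes `h' = tanh(x · W_ih^T + b_ih + h · W_hh^T + b_hh)`.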
pub struct RNNCell {
    pub weight_ih: Parameter,
    pub weight_hh: Parameter,
    pub bias_ih: Parameter,
    pub bias_hh: Parameter,
    input_size: usize,
    hidden_size: usize,
}

impl RNNCell {
    pub fn new(input_size: usize, hidden_size: usize) -> Self {
        Self {
            weight_ih: Parameter::named("weight_ih", xavier_uniform(input_size, hidden_size), true),
            weight_hh: Parameter::named(
                "weight_hh",
                xavier_uniform(hidden_size, hidden_size),
                true,
            ),
            bias_ih: Parameter::named("bias_ih", zeros(&[hidden_size]), true),
            bias_hh: Parameter::named("bias_hh", zeros(&[hidden_size]), true),
            input_size,
            hidden_size,
        }
    }

    pub fn input_size(&self) -> usize {
        self.input_size
    }

    pub fn hidden_size(&self) -> usize {
        self.hidden_size
    }

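    /// Runs one recurrence step and returns the next hidden state.
    ///
    /// Panics if the last dimension of `input` does not match `input_size`.
    ///
    /// A minimal usage sketch (hedged: written against the same `Variable`
    /// and `Tensor::from_vec` APIs the tests below use, so it is marked
    /// `ignore` rather than run as a doctest):
    ///
    /// ```ignore
    /// use axonml_autograd::Variable;
    /// use axonml_tensor::Tensor;
    ///
    /// let cell = RNNCell::new(4, 8);
    /// // Batch of two samples with four features each.
    /// let x = Variable::new(Tensor::from_vec(vec![0.5; 2 * 4], &[2, 4]).unwrap(), false);
    /// // Zero initial hidden state of shape [batch, hidden_size].
    /// let h0 = Variable::new(Tensor::from_vec(vec![0.0; 2 * 8], &[2, 8]).unwrap(), false);
    /// let h1 = cell.forward_step(&x, &h0);
    /// assert_eq!(h1.shape(), vec![2, 8]);
    /// ```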
    pub fn forward_step(&self, input: &Variable, hidden: &Variable) -> Variable {
        let input_features = input.data().shape().last().copied().unwrap_or(0);
        assert_eq!(
            input_features, self.input_size,
            "RNNCell: expected input size {}, got {}",
            self.input_size, input_features
        );

        // Input-to-hidden projection: x · W_ih^T + b_ih.
        let weight_ih = self.weight_ih.variable();
        let weight_ih_t = weight_ih.transpose(0, 1);
        let ih = input.matmul(&weight_ih_t);
        let bias_ih = self.bias_ih.variable();
        let ih = ih.add_var(&bias_ih);

        // Hidden-to-hidden projection: h · W_hh^T + b_hh.
        let weight_hh = self.weight_hh.variable();
        let weight_hh_t = weight_hh.transpose(0, 1);
        let hh = hidden.matmul(&weight_hh_t);
        let bias_hh = self.bias_hh.variable();
        let hh = hh.add_var(&bias_hh);

        ih.add_var(&hh).tanh()
    }
}

impl Module for RNNCell {
    fn forward(&self, input: &Variable) -> Variable {
        let batch_size = input.shape()[0];
        // The default forward starts from a zero hidden state.
        let hidden = Variable::new(
            zeros(&[batch_size, self.hidden_size]),
            input.requires_grad(),
        );
        self.forward_step(input, &hidden)
    }

    fn parameters(&self) -> Vec<Parameter> {
        vec![
            self.weight_ih.clone(),
            self.weight_hh.clone(),
            self.bias_ih.clone(),
            self.bias_hh.clone(),
        ]
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut params = HashMap::new();
        params.insert("weight_ih".to_string(), self.weight_ih.clone());
        params.insert("weight_hh".to_string(), self.weight_hh.clone());
        params.insert("bias_ih".to_string(), self.bias_ih.clone());
        params.insert("bias_hh".to_string(), self.bias_hh.clone());
        params
    }

    fn name(&self) -> &'static str {
        "RNNCell"
    }
}

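/// A multi-layer Elman RNN applied over a full sequence.
///
/// Layer 0 consumes the input sequence; each deeper layer consumes the hidden
/// states of the layer below. With `batch_first` the input is
/// `[batch, seq, features]`, otherwise `[seq, batch, features]`.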
pub struct RNN {
    cells: Vec<RNNCell>,
    _input_size: usize,
    hidden_size: usize,
    num_layers: usize,
    batch_first: bool,
}

impl RNN {
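    /// Creates an RNN with `batch_first = true`; see `with_options` for the
    /// seq-first layout.
    ///
    /// A minimal sketch mirroring the `test_rnn` case below (marked `ignore`
    /// since the surrounding crate paths are assumptions):
    ///
    /// ```ignore
    /// let rnn = RNN::new(10, 20, 2);
    /// // [batch = 2, seq = 5, features = 10]
    /// let input = Variable::new(Tensor::from_vec(vec![1.0; 100], &[2, 5, 10]).unwrap(), false);
    /// let output = rnn.forward(&input);
    /// assert_eq!(output.shape(), vec![2, 5, 20]); // [batch, seq, hidden]
    /// ```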
    pub fn new(input_size: usize, hidden_size: usize, num_layers: usize) -> Self {
        Self::with_options(input_size, hidden_size, num_layers, true)
    }

    pub fn with_options(
        input_size: usize,
        hidden_size: usize,
        num_layers: usize,
        batch_first: bool,
    ) -> Self {
        let mut cells = Vec::with_capacity(num_layers);

        // Layer 0 maps input_size -> hidden_size; deeper layers are
        // hidden -> hidden.
        cells.push(RNNCell::new(input_size, hidden_size));
        for _ in 1..num_layers {
            cells.push(RNNCell::new(hidden_size, hidden_size));
        }

        Self {
            cells,
            _input_size: input_size,
            hidden_size,
            num_layers,
            batch_first,
        }
    }
}

impl Module for RNN {
    fn forward(&self, input: &Variable) -> Variable {
        let shape = input.shape();
        let (batch_size, seq_len, input_features) = if self.batch_first {
            (shape[0], shape[1], shape[2])
        } else {
            (shape[1], shape[0], shape[2])
        };

        // One zero-initialized hidden state per layer.
        let mut hiddens: Vec<Variable> = (0..self.num_layers)
            .map(|_| {
                Variable::new(
                    zeros(&[batch_size, self.hidden_size]),
                    input.requires_grad(),
                )
            })
            .collect();

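        // Precompute layer 0's input-to-hidden projection for every time step
        // with one batched matmul; only the recurrent hidden-to-hidden matmul
        // has to run inside the time loop.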
        let cell0 = &self.cells[0];
        let input_2d = input.reshape(&[batch_size * seq_len, input_features]);
        let w_ih_t = cell0.weight_ih.variable().transpose(0, 1);
        let ih_all = input_2d.matmul(&w_ih_t).add_var(&cell0.bias_ih.variable());
        let ih_all_3d = ih_all.reshape(&[batch_size, seq_len, self.hidden_size]);

        let w_hh_t_0 = cell0.weight_hh.variable().transpose(0, 1);
        let bias_hh_0 = cell0.bias_hh.variable();

        let mut outputs = Vec::with_capacity(seq_len);

        for t in 0..seq_len {
            let ih_t = ih_all_3d.select(1, t);
            let hh = hiddens[0].matmul(&w_hh_t_0).add_var(&bias_hh_0);
            hiddens[0] = ih_t.add_var(&hh).tanh();

            // Deeper layers consume the hidden state of the layer below.
            for l in 1..self.num_layers {
                let layer_input = hiddens[l - 1].clone();
                hiddens[l] = self.cells[l].forward_step(&layer_input, &hiddens[l]);
            }

            outputs.push(hiddens[self.num_layers - 1].clone());
        }

        // Stack per-step outputs along the time dimension (1 if batch_first,
        // else 0).
        let time_dim = usize::from(self.batch_first);
        let unsqueezed: Vec<Variable> = outputs.iter().map(|o| o.unsqueeze(time_dim)).collect();
        let refs: Vec<&Variable> = unsqueezed.iter().collect();
        Variable::cat(&refs, time_dim)
    }

    fn parameters(&self) -> Vec<Parameter> {
        self.cells.iter().flat_map(|c| c.parameters()).collect()
    }

    fn name(&self) -> &'static str {
        "RNN"
    }
}

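/// A single LSTM cell with fused gate weights.
///
/// `weight_ih` and `weight_hh` stack the input (i), forget (f), cell (g), and
/// output (o) gates along the output dimension, so each projects to
/// `4 * hidden_size` features.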
pub struct LSTMCell {
    pub weight_ih: Parameter,
    pub weight_hh: Parameter,
    pub bias_ih: Parameter,
    pub bias_hh: Parameter,
    input_size: usize,
    hidden_size: usize,
}

impl LSTMCell {
    pub fn new(input_size: usize, hidden_size: usize) -> Self {
        Self {
            weight_ih: Parameter::named(
                "weight_ih",
                xavier_uniform(input_size, 4 * hidden_size),
                true,
            ),
            weight_hh: Parameter::named(
                "weight_hh",
                xavier_uniform(hidden_size, 4 * hidden_size),
                true,
            ),
            bias_ih: Parameter::named("bias_ih", zeros(&[4 * hidden_size]), true),
            bias_hh: Parameter::named("bias_hh", zeros(&[4 * hidden_size]), true),
            input_size,
            hidden_size,
        }
    }

    pub fn input_size(&self) -> usize {
        self.input_size
    }

    pub fn hidden_size(&self) -> usize {
        self.hidden_size
    }

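    /// Runs one LSTM step on the `(hidden, cell)` pair in `hx` and returns
    /// the updated pair.
    ///
    /// The fused projection is sliced as `[i | f | g | o]`, each slice
    /// `hidden_size` wide:
    ///
    /// ```text
    /// i  = sigmoid(gates[..,    0 ..   hs])   // input gate
    /// f  = sigmoid(gates[..,   hs .. 2*hs])   // forget gate
    /// g  = tanh   (gates[.., 2*hs .. 3*hs])   // candidate cell
    /// o  = sigmoid(gates[.., 3*hs .. 4*hs])   // output gate
    /// c' = f * c + i * g
    /// h' = o * tanh(c')
    /// ```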
    pub fn forward_step(
        &self,
        input: &Variable,
        hx: &(Variable, Variable),
    ) -> (Variable, Variable) {
        let input_features = input.data().shape().last().copied().unwrap_or(0);
        assert_eq!(
            input_features, self.input_size,
            "LSTMCell: expected input size {}, got {}",
            self.input_size, input_features
        );

        let (h, c) = hx;

        // Fused gate projections: x · W_ih^T + b_ih and h · W_hh^T + b_hh.
        let weight_ih = self.weight_ih.variable();
        let weight_ih_t = weight_ih.transpose(0, 1);
        let ih = input.matmul(&weight_ih_t);
        let bias_ih = self.bias_ih.variable();
        let ih = ih.add_var(&bias_ih);

        let weight_hh = self.weight_hh.variable();
        let weight_hh_t = weight_hh.transpose(0, 1);
        let hh = h.matmul(&weight_hh_t);
        let bias_hh = self.bias_hh.variable();
        let hh = hh.add_var(&bias_hh);

        let gates = ih.add_var(&hh);
        let hs = self.hidden_size;

        // Slice the fused projection into the four gates: [i | f | g | o].
        let i = gates.narrow(1, 0, hs).sigmoid();
        let f = gates.narrow(1, hs, hs).sigmoid();
        let g = gates.narrow(1, 2 * hs, hs).tanh();
        let o = gates.narrow(1, 3 * hs, hs).sigmoid();

        let c_new = f.mul_var(c).add_var(&i.mul_var(&g));
        let h_new = o.mul_var(&c_new.tanh());

        (h_new, c_new)
    }
}

impl Module for LSTMCell {
    fn forward(&self, input: &Variable) -> Variable {
        let batch_size = input.shape()[0];
        let h = Variable::new(
            zeros(&[batch_size, self.hidden_size]),
            input.requires_grad(),
        );
        let c = Variable::new(
            zeros(&[batch_size, self.hidden_size]),
            input.requires_grad(),
        );
        let (h_new, _) = self.forward_step(input, &(h, c));
        h_new
    }

    fn parameters(&self) -> Vec<Parameter> {
        vec![
            self.weight_ih.clone(),
            self.weight_hh.clone(),
            self.bias_ih.clone(),
            self.bias_hh.clone(),
        ]
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut params = HashMap::new();
        params.insert("weight_ih".to_string(), self.weight_ih.clone());
        params.insert("weight_hh".to_string(), self.weight_hh.clone());
        params.insert("bias_ih".to_string(), self.bias_ih.clone());
        params.insert("bias_hh".to_string(), self.bias_hh.clone());
        params
    }

    fn name(&self) -> &'static str {
        "LSTMCell"
    }
}

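/// A multi-layer LSTM applied over a full sequence.
///
/// On `cuda` builds, layer 0's per-step gate math is dispatched to a fused
/// kernel when one is available; otherwise the same update is composed from
/// ordinary autograd ops.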
pub struct LSTM {
    cells: Vec<LSTMCell>,
    input_size: usize,
    hidden_size: usize,
    num_layers: usize,
    batch_first: bool,
}

impl LSTM {
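    /// Creates an LSTM with `batch_first = true`.
    ///
    /// A minimal sketch mirroring the `test_lstm` case below (marked `ignore`
    /// since the surrounding crate paths are assumptions):
    ///
    /// ```ignore
    /// let lstm = LSTM::new(10, 20, 1);
    /// // [batch = 2, seq = 5, features = 10]
    /// let input = Variable::new(Tensor::from_vec(vec![1.0; 100], &[2, 5, 10]).unwrap(), false);
    /// let output = lstm.forward(&input);
    /// assert_eq!(output.shape(), vec![2, 5, 20]);
    /// ```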
    pub fn new(input_size: usize, hidden_size: usize, num_layers: usize) -> Self {
        Self::with_options(input_size, hidden_size, num_layers, true)
    }

    pub fn with_options(
        input_size: usize,
        hidden_size: usize,
        num_layers: usize,
        batch_first: bool,
    ) -> Self {
        let mut cells = Vec::with_capacity(num_layers);
        cells.push(LSTMCell::new(input_size, hidden_size));
        for _ in 1..num_layers {
            cells.push(LSTMCell::new(hidden_size, hidden_size));
        }

        Self {
            cells,
            input_size,
            hidden_size,
            num_layers,
            batch_first,
        }
    }

    pub fn input_size(&self) -> usize {
        self.input_size
    }

    pub fn hidden_size(&self) -> usize {
        self.hidden_size
    }

    pub fn num_layers(&self) -> usize {
        self.num_layers
    }
}

impl Module for LSTM {
    fn forward(&self, input: &Variable) -> Variable {
        let shape = input.shape();
        let (batch_size, seq_len, input_features) = if self.batch_first {
            (shape[0], shape[1], shape[2])
        } else {
            (shape[1], shape[0], shape[2])
        };

        let input_device = input.data().device();
        #[cfg(feature = "cuda")]
        let on_gpu = input_device.is_gpu();
        #[cfg(not(feature = "cuda"))]
        let on_gpu = false;

        // One zero-initialized (h, c) pair per layer, on the input's device.
        let mut states: Vec<(Variable, Variable)> = (0..self.num_layers)
            .map(|_| {
                let make_h = || {
                    let h_cpu = zeros(&[batch_size, self.hidden_size]);
                    let h_tensor = if on_gpu {
                        h_cpu
                            .to_device(input_device)
                            .expect("LSTM: failed to move hidden state to GPU")
                    } else {
                        h_cpu
                    };
                    Variable::new(h_tensor, input.requires_grad())
                };
                (make_h(), make_h())
            })
            .collect();

        // Precompute layer 0's input-to-hidden projection for every time step
        // with one batched matmul.
        let cell0 = &self.cells[0];
        let input_2d = input.reshape(&[batch_size * seq_len, input_features]);
        let w_ih_t = cell0.weight_ih.variable().transpose(0, 1);
        let ih_all = input_2d.matmul(&w_ih_t).add_var(&cell0.bias_ih.variable());
        let ih_all_3d = ih_all.reshape(&[batch_size, seq_len, 4 * self.hidden_size]);

        let w_hh_t_0 = cell0.weight_hh.variable().transpose(0, 1);
        let bias_hh_0 = cell0.bias_hh.variable();

        let mut outputs = Vec::with_capacity(seq_len);

        for t in 0..seq_len {
            let ih_t = ih_all_3d.select(1, t);
            let (h, c) = &states[0];

            let hh = h.matmul(&w_hh_t_0).add_var(&bias_hh_0);
            let gates = ih_t.add_var(&hh);

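            // On GPU, try the fused gate kernel: it applies all four gate
            // nonlinearities plus the cell update in one launch, and wires up
            // a custom backward node so autograd still flows through it. If
            // the kernel declines (returns `None`), the state is left
            // unchanged for this step.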
            if on_gpu {
                #[cfg(feature = "cuda")]
                {
                    let hs = self.hidden_size;
                    let gates_data = gates.data();
                    let c_data = c.data();

                    if let Some((h_tensor, c_tensor)) = gates_data.lstm_gates_fused(&c_data, hs) {
                        // Save what the fused backward pass needs.
                        let saved_gates = gates_data.clone();
                        let saved_c_prev = c_data.clone();
                        let saved_c_new = c_tensor.clone();

                        let backward_fn = axonml_autograd::LstmGatesBackward::new(
                            gates.grad_fn().cloned(),
                            c.grad_fn().cloned(),
                            saved_gates,
                            saved_c_prev,
                            saved_c_new,
                            hs,
                        );
                        let grad_fn = axonml_autograd::GradFn::new(backward_fn);

                        let fused_requires_grad = gates.requires_grad() || c.requires_grad();
                        let h_new = Variable::from_operation(
                            h_tensor,
                            grad_fn.clone(),
                            fused_requires_grad,
                        );
                        let c_new =
                            Variable::from_operation(c_tensor, grad_fn, fused_requires_grad);
                        states[0] = (h_new, c_new);
                    }
                }
            } else {
                // CPU path: compose the gate math from autograd ops.
                let hs = self.hidden_size;
                let i_gate = gates.narrow(1, 0, hs).sigmoid();
                let f_gate = gates.narrow(1, hs, hs).sigmoid();
                let g_gate = gates.narrow(1, 2 * hs, hs).tanh();
                let o_gate = gates.narrow(1, 3 * hs, hs).sigmoid();
                let c_new = f_gate.mul_var(c).add_var(&i_gate.mul_var(&g_gate));
                let h_new = o_gate.mul_var(&c_new.tanh());
                states[0] = (h_new, c_new);
            }

            for l in 1..self.num_layers {
                let layer_input = states[l - 1].0.clone();
                states[l] = self.cells[l].forward_step(&layer_input, &states[l]);
            }

            outputs.push(states[self.num_layers - 1].0.clone());
        }

        let time_dim = usize::from(self.batch_first);
        let unsqueezed: Vec<Variable> = outputs.iter().map(|o| o.unsqueeze(time_dim)).collect();
        let refs: Vec<&Variable> = unsqueezed.iter().collect();
        Variable::cat(&refs, time_dim)
    }

    fn parameters(&self) -> Vec<Parameter> {
        self.cells.iter().flat_map(|c| c.parameters()).collect()
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut params = HashMap::new();
        // Single-layer models keep flat parameter names; multi-layer models
        // prefix each name with its cell index.
        if self.cells.len() == 1 {
            for (n, p) in self.cells[0].named_parameters() {
                params.insert(n, p);
            }
        } else {
            for (i, cell) in self.cells.iter().enumerate() {
                for (n, p) in cell.named_parameters() {
                    params.insert(format!("cells.{i}.{n}"), p);
                }
            }
        }
        params
    }

    fn name(&self) -> &'static str {
        "LSTM"
    }
}

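/// A single GRU cell with fused gate weights.
///
/// `weight_ih` and `weight_hh` stack the reset (r), update (z), and candidate
/// (n) gates, so each projects to `3 * hidden_size` features.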
pub struct GRUCell {
    pub weight_ih: Parameter,
    pub weight_hh: Parameter,
    pub bias_ih: Parameter,
    pub bias_hh: Parameter,
    input_size: usize,
    hidden_size: usize,
}

impl GRUCell {
    pub fn new(input_size: usize, hidden_size: usize) -> Self {
        Self {
            weight_ih: Parameter::named(
                "weight_ih",
                xavier_uniform(input_size, 3 * hidden_size),
                true,
            ),
            weight_hh: Parameter::named(
                "weight_hh",
                xavier_uniform(hidden_size, 3 * hidden_size),
                true,
            ),
            bias_ih: Parameter::named("bias_ih", zeros(&[3 * hidden_size]), true),
            bias_hh: Parameter::named("bias_hh", zeros(&[3 * hidden_size]), true),
            input_size,
            hidden_size,
        }
    }

    pub fn input_size(&self) -> usize {
        self.input_size
    }

    pub fn hidden_size(&self) -> usize {
        self.hidden_size
    }
}

impl GRUCell {
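    /// Runs one GRU step and returns the next hidden state.
    ///
    /// The fused projections are sliced as `[r | z | n]`, each slice
    /// `hidden_size` wide:
    ///
    /// ```text
    /// r  = sigmoid(ih_r + hh_r)   // reset gate
    /// z  = sigmoid(ih_z + hh_z)   // update gate
    /// n  = tanh(ih_n + r * hh_n)  // candidate state
    /// h' = n + z * (h - n)        // == (1 - z) * n + z * h
    /// ```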
    pub fn forward_step(&self, input: &Variable, hidden: &Variable) -> Variable {
        let hidden_size = self.hidden_size;

        let weight_ih = self.weight_ih.variable();
        let weight_hh = self.weight_hh.variable();
        let bias_ih = self.bias_ih.variable();
        let bias_hh = self.bias_hh.variable();

        // Fused gate projections for the input and the previous hidden state.
        let weight_ih_t = weight_ih.transpose(0, 1);
        let ih = input.matmul(&weight_ih_t).add_var(&bias_ih);

        let weight_hh_t = weight_hh.transpose(0, 1);
        let hh = hidden.matmul(&weight_hh_t).add_var(&bias_hh);

        // Slice the fused projections into the three gates: [r | z | n].
        let ih_r = ih.narrow(1, 0, hidden_size);
        let ih_z = ih.narrow(1, hidden_size, hidden_size);
        let ih_n = ih.narrow(1, 2 * hidden_size, hidden_size);

        let hh_r = hh.narrow(1, 0, hidden_size);
        let hh_z = hh.narrow(1, hidden_size, hidden_size);
        let hh_n = hh.narrow(1, 2 * hidden_size, hidden_size);

        let r = ih_r.add_var(&hh_r).sigmoid();
        let z = ih_z.add_var(&hh_z).sigmoid();
        let n = ih_n.add_var(&r.mul_var(&hh_n)).tanh();

        // h' = n + z * (h - n), i.e. (1 - z) * n + z * h.
        let h_minus_n = hidden.sub_var(&n);
        n.add_var(&z.mul_var(&h_minus_n))
    }
}

impl Module for GRUCell {
    fn forward(&self, input: &Variable) -> Variable {
        let batch_size = input.shape()[0];

        let hidden = Variable::new(
            zeros(&[batch_size, self.hidden_size]),
            input.requires_grad(),
        );

        self.forward_step(input, &hidden)
    }

    fn parameters(&self) -> Vec<Parameter> {
        vec![
            self.weight_ih.clone(),
            self.weight_hh.clone(),
            self.bias_ih.clone(),
            self.bias_hh.clone(),
        ]
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut params = HashMap::new();
        params.insert("weight_ih".to_string(), self.weight_ih.clone());
        params.insert("weight_hh".to_string(), self.weight_hh.clone());
        params.insert("bias_ih".to_string(), self.bias_ih.clone());
        params.insert("bias_hh".to_string(), self.bias_hh.clone());
        params
    }

    fn name(&self) -> &'static str {
        "GRUCell"
    }
}
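/// A multi-layer GRU applied over a full sequence.
///
/// Instances built via `new` are always `batch_first`. As with `LSTM`, layer
/// 0's per-step gate math uses a fused kernel on `cuda` builds when available.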
pub struct GRU {
    cells: Vec<GRUCell>,
    hidden_size: usize,
    num_layers: usize,
    batch_first: bool,
}

impl GRU {
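    /// Creates a GRU with `batch_first = true`.
    ///
    /// A minimal sketch mirroring the GRU tests below (marked `ignore` since
    /// the surrounding crate paths are assumptions):
    ///
    /// ```ignore
    /// let gru = GRU::new(4, 8, 1);
    /// // [batch = 2, seq = 5, features = 4]
    /// let input = Variable::new(Tensor::from_vec(vec![1.0; 2 * 5 * 4], &[2, 5, 4]).unwrap(), false);
    /// let output = gru.forward(&input);
    /// assert_eq!(output.shape(), vec![2, 5, 8]);
    /// ```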
    pub fn new(input_size: usize, hidden_size: usize, num_layers: usize) -> Self {
        let mut cells = Vec::with_capacity(num_layers);
        cells.push(GRUCell::new(input_size, hidden_size));
        for _ in 1..num_layers {
            cells.push(GRUCell::new(hidden_size, hidden_size));
        }
        Self {
            cells,
            hidden_size,
            num_layers,
            batch_first: true,
        }
    }

    pub fn hidden_size(&self) -> usize {
        self.hidden_size
    }

    pub fn num_layers(&self) -> usize {
        self.num_layers
    }
}

impl Module for GRU {
    fn forward(&self, input: &Variable) -> Variable {
        let shape = input.shape();
        let (batch_size, seq_len, input_features) = if self.batch_first {
            (shape[0], shape[1], shape[2])
        } else {
            (shape[1], shape[0], shape[2])
        };

        #[cfg(feature = "cuda")]
        let on_gpu = input.data().device().is_gpu();
        #[cfg(not(feature = "cuda"))]
        let on_gpu = false;

        let input_device = input.data().device();

        // One zero-initialized hidden state per layer, on the input's device.
        let mut hidden_states: Vec<Variable> = (0..self.num_layers)
            .map(|_| {
                let h_cpu = zeros(&[batch_size, self.hidden_size]);
                let h_tensor = if on_gpu {
                    h_cpu
                        .to_device(input_device)
                        .expect("GRU: failed to move hidden state to GPU")
                } else {
                    h_cpu
                };
                Variable::new(h_tensor, input.requires_grad())
            })
            .collect();

        // Precompute layer 0's input-to-hidden projection for every time step
        // with one batched matmul.
        let cell0 = &self.cells[0];
        let input_2d = input.reshape(&[batch_size * seq_len, input_features]);
        let w_ih_t = cell0.weight_ih.variable().transpose(0, 1);
        let ih_all = input_2d.matmul(&w_ih_t).add_var(&cell0.bias_ih.variable());
        let ih_all_3d = ih_all.reshape(&[batch_size, seq_len, 3 * self.hidden_size]);

        let w_hh_t_0 = cell0.weight_hh.variable().transpose(0, 1);
        let bias_hh_0 = cell0.bias_hh.variable();

        let mut output_vars: Vec<Variable> = Vec::with_capacity(seq_len);

        for t in 0..seq_len {
            let ih_t = ih_all_3d.select(1, t);
            let hidden = &hidden_states[0];
            let hs = self.hidden_size;

            let hh = hidden.matmul(&w_hh_t_0).add_var(&bias_hh_0);

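            // On GPU, try the fused GRU gate kernel; as in the LSTM path, a
            // custom backward node keeps autograd intact, and a `None` from
            // the kernel leaves the hidden state unchanged for this step.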
            if on_gpu {
                #[cfg(feature = "cuda")]
                {
                    let ih_data = ih_t.data();
                    let hh_data = hh.data();
                    let h_data = hidden.data();

                    if let Some(h_tensor) = ih_data.gru_gates_fused(&hh_data, &h_data, hs) {
                        // Save what the fused backward pass needs.
                        let saved_ih = ih_data.clone();
                        let saved_hh = hh_data.clone();
                        let saved_h_prev = h_data.clone();

                        let backward_fn = axonml_autograd::GruGatesBackward::new(
                            ih_t.grad_fn().cloned(),
                            hh.grad_fn().cloned(),
                            hidden.grad_fn().cloned(),
                            saved_ih,
                            saved_hh,
                            saved_h_prev,
                            hs,
                        );
                        let grad_fn = axonml_autograd::GradFn::new(backward_fn);

                        let fused_requires_grad =
                            ih_t.requires_grad() || hh.requires_grad() || hidden.requires_grad();
                        let h_new =
                            Variable::from_operation(h_tensor, grad_fn, fused_requires_grad);
                        hidden_states[0] = h_new;
                    }
                }
            } else {
                // CPU path: compose the gate math from autograd ops.
                let ih_r = ih_t.narrow(1, 0, hs);
                let ih_z = ih_t.narrow(1, hs, hs);
                let ih_n = ih_t.narrow(1, 2 * hs, hs);
                let hh_r = hh.narrow(1, 0, hs);
                let hh_z = hh.narrow(1, hs, hs);
                let hh_n = hh.narrow(1, 2 * hs, hs);

                let r = ih_r.add_var(&hh_r).sigmoid();
                let z = ih_z.add_var(&hh_z).sigmoid();
                let n = ih_n.add_var(&r.mul_var(&hh_n)).tanh();
                let h_minus_n = hidden.sub_var(&n);
                let h_new = n.add_var(&z.mul_var(&h_minus_n));
                hidden_states[0] = h_new;
            }

            // Feed the updated layer-0 output up through the remaining layers.
            let mut layer_output = hidden_states[0].clone();
            for l in 1..self.num_layers {
                let new_hidden = self.cells[l].forward_step(&layer_output, &hidden_states[l]);
                hidden_states[l] = new_hidden.clone();
                layer_output = new_hidden;
            }

            output_vars.push(layer_output);
        }

        self.stack_outputs(&output_vars, batch_size, seq_len)
    }

    fn parameters(&self) -> Vec<Parameter> {
        self.cells.iter().flat_map(|c| c.parameters()).collect()
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut params = HashMap::new();
        // Same naming scheme as LSTM: flat names for a single layer,
        // `cells.{i}.` prefixes otherwise.
        if self.cells.len() == 1 {
            for (n, p) in self.cells[0].named_parameters() {
                params.insert(n, p);
            }
        } else {
            for (i, cell) in self.cells.iter().enumerate() {
                for (n, p) in cell.named_parameters() {
                    params.insert(format!("cells.{i}.{n}"), p);
                }
            }
        }
        params
    }

    fn name(&self) -> &'static str {
        "GRU"
    }
}

impl GRU {
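    /// Runs the GRU and returns the mean of the top layer's outputs over all
    /// time steps, shape `[batch, hidden_size]`: a simple pooling head for
    /// sequence-level predictions.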
    pub fn forward_mean(&self, input: &Variable) -> Variable {
        let shape = input.shape();
        let (batch_size, seq_len, input_features) = if self.batch_first {
            (shape[0], shape[1], shape[2])
        } else {
            (shape[1], shape[0], shape[2])
        };

        let mut hidden_states: Vec<Variable> = (0..self.num_layers)
            .map(|_| {
                Variable::new(
                    zeros(&[batch_size, self.hidden_size]),
                    input.requires_grad(),
                )
            })
            .collect();

        let cell0 = &self.cells[0];
        let input_2d = input.reshape(&[batch_size * seq_len, input_features]);
        let w_ih_t = cell0.weight_ih.variable().transpose(0, 1);
        let ih_all = input_2d.matmul(&w_ih_t).add_var(&cell0.bias_ih.variable());
        let ih_all_3d = ih_all.reshape(&[batch_size, seq_len, 3 * self.hidden_size]);

        let w_hh_t_0 = cell0.weight_hh.variable().transpose(0, 1);
        let bias_hh_0 = cell0.bias_hh.variable();

        // Accumulate the top layer's output at each step instead of storing
        // the whole sequence.
        let mut output_sum: Option<Variable> = None;
        let hs = self.hidden_size;

        for t in 0..seq_len {
            let ih_t = ih_all_3d.select(1, t);
            let hidden = &hidden_states[0];
            let hh = hidden.matmul(&w_hh_t_0).add_var(&bias_hh_0);

            let ih_r = ih_t.narrow(1, 0, hs);
            let ih_z = ih_t.narrow(1, hs, hs);
            let ih_n = ih_t.narrow(1, 2 * hs, hs);
            let hh_r = hh.narrow(1, 0, hs);
            let hh_z = hh.narrow(1, hs, hs);
            let hh_n = hh.narrow(1, 2 * hs, hs);

            let r = ih_r.add_var(&hh_r).sigmoid();
            let z = ih_z.add_var(&hh_z).sigmoid();
            let n = ih_n.add_var(&r.mul_var(&hh_n)).tanh();
            let h_minus_n = hidden.sub_var(&n);
            let h_new = n.add_var(&z.mul_var(&h_minus_n));
            hidden_states[0] = h_new.clone();

            let mut layer_output = h_new;
            for l in 1..self.num_layers {
                let new_hidden = self.cells[l].forward_step(&layer_output, &hidden_states[l]);
                hidden_states[l] = new_hidden.clone();
                layer_output = new_hidden;
            }

            output_sum = Some(match output_sum {
                None => layer_output,
                Some(acc) => acc.add_var(&layer_output),
            });
        }

        match output_sum {
            Some(sum) => sum.mul_scalar(1.0 / seq_len as f32),
            None => Variable::new(zeros(&[batch_size, self.hidden_size]), false),
        }
    }

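    /// Runs the GRU and returns only the top layer's final hidden state,
    /// shape `[batch, hidden_size]`.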
    pub fn forward_last(&self, input: &Variable) -> Variable {
        let shape = input.shape();
        let (batch_size, seq_len, input_features) = if self.batch_first {
            (shape[0], shape[1], shape[2])
        } else {
            (shape[1], shape[0], shape[2])
        };

        let mut hidden_states: Vec<Variable> = (0..self.num_layers)
            .map(|_| {
                Variable::new(
                    zeros(&[batch_size, self.hidden_size]),
                    input.requires_grad(),
                )
            })
            .collect();

        let cell0 = &self.cells[0];
        let input_2d = input.reshape(&[batch_size * seq_len, input_features]);
        let w_ih_t = cell0.weight_ih.variable().transpose(0, 1);
        let ih_all = input_2d.matmul(&w_ih_t).add_var(&cell0.bias_ih.variable());
        let ih_all_3d = ih_all.reshape(&[batch_size, seq_len, 3 * self.hidden_size]);

        let w_hh_t_0 = cell0.weight_hh.variable().transpose(0, 1);
        let bias_hh_0 = cell0.bias_hh.variable();
        let hs = self.hidden_size;

        for t in 0..seq_len {
            let ih_t = ih_all_3d.select(1, t);
            let hidden = &hidden_states[0];
            let hh = hidden.matmul(&w_hh_t_0).add_var(&bias_hh_0);

            let ih_r = ih_t.narrow(1, 0, hs);
            let ih_z = ih_t.narrow(1, hs, hs);
            let ih_n = ih_t.narrow(1, 2 * hs, hs);
            let hh_r = hh.narrow(1, 0, hs);
            let hh_z = hh.narrow(1, hs, hs);
            let hh_n = hh.narrow(1, 2 * hs, hs);

            let r = ih_r.add_var(&hh_r).sigmoid();
            let z = ih_z.add_var(&hh_z).sigmoid();
            let n = ih_n.add_var(&r.mul_var(&hh_n)).tanh();
            let h_minus_n = hidden.sub_var(&n);
            let h_new = n.add_var(&z.mul_var(&h_minus_n));
            hidden_states[0] = h_new.clone();

            let mut layer_input = h_new;
            for (layer_idx, cell) in self.cells.iter().enumerate().skip(1) {
                let new_hidden = cell.forward_step(&layer_input, &hidden_states[layer_idx]);
                hidden_states[layer_idx] = new_hidden.clone();
                layer_input = new_hidden;
            }
        }

        // The top layer's hidden state after the final step.
        hidden_states
            .pop()
            .unwrap_or_else(|| Variable::new(zeros(&[batch_size, self.hidden_size]), false))
    }

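    /// Stacks `seq_len` per-step outputs of shape `[batch, hidden]` into
    /// `[batch, seq, hidden]`. Note this helper assumes the `batch_first`
    /// layout, which is the only one `GRU::new` produces.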
    fn stack_outputs(&self, outputs: &[Variable], batch_size: usize, _seq_len: usize) -> Variable {
        if outputs.is_empty() {
            return Variable::new(zeros(&[batch_size, 0, self.hidden_size]), false);
        }

        let unsqueezed: Vec<Variable> = outputs.iter().map(|o| o.unsqueeze(1)).collect();
        let refs: Vec<&Variable> = unsqueezed.iter().collect();
        Variable::cat(&refs, 1)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use axonml_tensor::Tensor;

    #[test]
    fn test_rnn_cell() {
        let cell = RNNCell::new(10, 20);
        let input = Variable::new(Tensor::from_vec(vec![1.0; 20], &[2, 10]).unwrap(), false);
        let hidden = Variable::new(Tensor::from_vec(vec![0.0; 40], &[2, 20]).unwrap(), false);
        let output = cell.forward_step(&input, &hidden);
        assert_eq!(output.shape(), vec![2, 20]);
    }

    #[test]
    fn test_rnn() {
        let rnn = RNN::new(10, 20, 2);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 100], &[2, 5, 10]).unwrap(),
            false,
        );
        let output = rnn.forward(&input);
        assert_eq!(output.shape(), vec![2, 5, 20]);
    }

    #[test]
    fn test_lstm() {
        let lstm = LSTM::new(10, 20, 1);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 100], &[2, 5, 10]).unwrap(),
            false,
        );
        let output = lstm.forward(&input);
        assert_eq!(output.shape(), vec![2, 5, 20]);
    }

    #[test]
    fn test_gru_gradients_reach_parameters() {
        let gru = GRU::new(4, 8, 1);
        let input = Variable::new(
            Tensor::from_vec(vec![0.5f32; 2 * 3 * 4], &[2, 3, 4]).unwrap(),
            true,
        );
        let output = gru.forward(&input);
        println!(
            "Output shape: {:?}, requires_grad: {}",
            output.shape(),
            output.requires_grad()
        );
        let loss = output.sum();
        println!(
            "Loss: {:?}, requires_grad: {}",
            loss.data().to_vec(),
            loss.requires_grad()
        );
        loss.backward();

        println!(
            "Input grad: {:?}",
            input
                .grad()
                .map(|g| g.to_vec().iter().map(|x| x.abs()).sum::<f32>())
        );

        let params = gru.parameters();
        println!("Number of parameters: {}", params.len());
        let mut has_grad = false;
        for (i, p) in params.iter().enumerate() {
            let grad = p.grad();
            match grad {
                Some(g) => {
                    let gv = g.to_vec();
                    let sum_abs: f32 = gv.iter().map(|x| x.abs()).sum();
                    println!(
                        "Param {} shape {:?} requires_grad={}: grad sum_abs={:.6}",
                        i,
                        p.shape(),
                        p.requires_grad(),
                        sum_abs
                    );
                    if sum_abs > 0.0 {
                        has_grad = true;
                    }
                }
                None => {
                    println!(
                        "Param {} shape {:?} requires_grad={}: NO GRADIENT",
                        i,
                        p.shape(),
                        p.requires_grad()
                    );
                }
            }
        }
        assert!(
            has_grad,
            "At least one GRU parameter should have non-zero gradients"
        );
    }

    #[test]
    fn test_lstm_cell_forward_step() {
        let cell = LSTMCell::new(8, 16);
        let input = Variable::new(Tensor::from_vec(vec![1.0; 2 * 8], &[2, 8]).unwrap(), false);
        let hidden = Variable::new(
            Tensor::from_vec(vec![0.0; 2 * 16], &[2, 16]).unwrap(),
            false,
        );
        let cell_state = Variable::new(
            Tensor::from_vec(vec![0.0; 2 * 16], &[2, 16]).unwrap(),
            false,
        );
        let hx = (hidden, cell_state);
        let (h, c) = cell.forward_step(&input, &hx);
        assert_eq!(h.shape(), vec![2, 16]);
        assert_eq!(c.shape(), vec![2, 16]);
    }

    #[test]
    fn test_lstm_multi_layer() {
        let lstm = LSTM::new(8, 16, 3);
        assert_eq!(lstm.num_layers(), 3);
        assert_eq!(lstm.hidden_size(), 16);

        let input = Variable::new(
            Tensor::from_vec(vec![0.5; 2 * 5 * 8], &[2, 5, 8]).unwrap(),
            false,
        );
        let output = lstm.forward(&input);
        assert_eq!(output.shape(), vec![2, 5, 16]);
    }

    #[test]
    fn test_lstm_forward_last() {
        let lstm = LSTM::new(8, 16, 1);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 2 * 10 * 8], &[2, 10, 8]).unwrap(),
            false,
        );
        let output = lstm.forward(&input);
        assert_eq!(output.shape(), vec![2, 10, 16]);

        let out_vec = output.data().to_vec();
        // Final time step (t = 9) of the first batch element.
        let last_t0 = &out_vec[9 * 16..10 * 16];
        assert!(
            last_t0.iter().all(|v| v.is_finite()),
            "Last output should be finite"
        );
    }

    #[test]
    fn test_lstm_gradient_flow() {
        let lstm = LSTM::new(4, 8, 1);
        let input = Variable::new(
            Tensor::from_vec(vec![0.5; 1 * 3 * 4], &[1, 3, 4]).unwrap(),
            true,
        );
        let output = lstm.forward(&input);
        let loss = output.sum();
        loss.backward();

        let input_grad = input
            .grad()
            .expect("Input should have gradient through LSTM");
        assert_eq!(input_grad.shape(), &[1, 3, 4]);
        assert!(
            input_grad.to_vec().iter().any(|g| g.abs() > 1e-10),
            "LSTM should propagate gradients to input"
        );

        let params = lstm.parameters();
        let grads_exist = params.iter().any(|p| {
            p.grad()
                .map(|g| g.to_vec().iter().any(|v| v.abs() > 0.0))
                .unwrap_or(false)
        });
        assert!(grads_exist, "LSTM parameters should have gradients");
    }

    #[test]
    fn test_lstm_different_sequence_lengths() {
        let lstm = LSTM::new(4, 8, 1);

        let short = Variable::new(
            Tensor::from_vec(vec![1.0; 1 * 2 * 4], &[1, 2, 4]).unwrap(),
            false,
        );
        let out_short = lstm.forward(&short);
        assert_eq!(out_short.shape(), vec![1, 2, 8]);

        let long = Variable::new(
            Tensor::from_vec(vec![1.0; 1 * 20 * 4], &[1, 20, 4]).unwrap(),
            false,
        );
        let out_long = lstm.forward(&long);
        assert_eq!(out_long.shape(), vec![1, 20, 8]);
    }

    #[test]
    fn test_lstm_parameters_count() {
        let lstm = LSTM::new(10, 20, 1);
        // One layer holds 4*hs*(in + hs) weights plus 2*4*hs biases:
        // 80 * (10 + 20) + 160 = 2560 values for (10, 20).
        let n = lstm.parameters().iter().map(|p| p.numel()).sum::<usize>();
        assert!(n > 0, "LSTM should have parameters");
    }

    #[test]
    fn test_gru_cell_forward_step() {
        let cell = GRUCell::new(8, 16);
        assert_eq!(cell.input_size(), 8);
        assert_eq!(cell.hidden_size(), 16);

        let input = Variable::new(Tensor::from_vec(vec![1.0; 2 * 8], &[2, 8]).unwrap(), false);
        let hidden = Variable::new(
            Tensor::from_vec(vec![0.0; 2 * 16], &[2, 16]).unwrap(),
            false,
        );
        let output = cell.forward_step(&input, &hidden);
        assert_eq!(output.shape(), vec![2, 16]);
    }

    #[test]
    fn test_gru_multi_layer() {
        let gru = GRU::new(8, 16, 2);
        assert_eq!(gru.num_layers(), 2);
        assert_eq!(gru.hidden_size(), 16);

        let input = Variable::new(
            Tensor::from_vec(vec![0.5; 2 * 5 * 8], &[2, 5, 8]).unwrap(),
            false,
        );
        let output = gru.forward(&input);
        assert_eq!(output.shape(), vec![2, 5, 16]);
    }

    #[test]
    fn test_gru_forward_mean() {
        let gru = GRU::new(4, 8, 1);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 2 * 5 * 4], &[2, 5, 4]).unwrap(),
            false,
        );
        let mean_out = gru.forward_mean(&input);
        assert_eq!(mean_out.shape(), vec![2, 8]);
    }

    #[test]
    fn test_gru_forward_last() {
        let gru = GRU::new(4, 8, 1);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 2 * 5 * 4], &[2, 5, 4]).unwrap(),
            false,
        );
        let last_out = gru.forward_last(&input);
        assert_eq!(last_out.shape(), vec![2, 8]);
    }

    #[test]
    fn test_gru_gradient_flow_to_input() {
        let gru = GRU::new(4, 8, 1);
        let input = Variable::new(
            Tensor::from_vec(vec![0.5; 1 * 3 * 4], &[1, 3, 4]).unwrap(),
            true,
        );
        let output = gru.forward(&input);
        output.sum().backward();

        let grad = input
            .grad()
            .expect("Input should have gradient through GRU");
        assert_eq!(grad.shape(), &[1, 3, 4]);
        assert!(
            grad.to_vec().iter().any(|g| g.abs() > 1e-10),
            "GRU should propagate gradients"
        );
    }

    #[test]
    fn test_gru_hidden_state_evolves() {
        let gru = GRU::new(4, 8, 1);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 1 * 5 * 4], &[1, 5, 4]).unwrap(),
            false,
        );
        let output = gru.forward(&input);
        let out_vec = output.data().to_vec();

        // Compare the first (t = 0) and last (t = 4) time steps.
        let t0 = &out_vec[0..8];
        let t4 = &out_vec[4 * 8..5 * 8];
        let diff: f32 = t0.iter().zip(t4.iter()).map(|(a, b)| (a - b).abs()).sum();
        assert!(
            diff > 1e-6,
            "GRU hidden state should evolve over time, diff={}",
            diff
        );
    }

    #[test]
    fn test_rnn_cell_gradient_flow() {
        let cell = RNNCell::new(4, 8);
        let input = Variable::new(Tensor::from_vec(vec![1.0; 1 * 4], &[1, 4]).unwrap(), true);
        let hidden = Variable::new(Tensor::from_vec(vec![0.0; 1 * 8], &[1, 8]).unwrap(), false);
        let out = cell.forward_step(&input, &hidden);
        out.sum().backward();

        let grad = input.grad().expect("RNNCell should propagate gradients");
        assert_eq!(grad.shape(), &[1, 4]);
    }

    #[test]
    fn test_rnn_multi_layer() {
        let rnn = RNN::with_options(8, 16, 3, true);
        let input = Variable::new(
            Tensor::from_vec(vec![0.5; 2 * 5 * 8], &[2, 5, 8]).unwrap(),
            false,
        );
        let output = rnn.forward(&input);
        assert_eq!(output.shape(), vec![2, 5, 16]);
    }

    #[test]
    fn test_lstm_outputs_are_bounded() {
        let lstm = LSTM::new(4, 8, 1);
        let input = Variable::new(
            Tensor::from_vec(vec![100.0; 1 * 10 * 4], &[1, 10, 4]).unwrap(),
            false,
        );
        let output = lstm.forward(&input);
        let out_vec = output.data().to_vec();

        // h = o * tanh(c) with o in (0, 1), so |h| < 1 even for huge inputs.
        for v in &out_vec {
            assert!(v.is_finite(), "LSTM output should be finite, got {}", v);
            assert!(
                v.abs() <= 1.0 + 1e-5,
                "LSTM output should be bounded by tanh: got {}",
                v
            );
        }
    }

    #[test]
    fn test_gru_outputs_finite_with_large_input() {
        let gru = GRU::new(4, 8, 1);
        let input = Variable::new(
            Tensor::from_vec(vec![50.0; 1 * 5 * 4], &[1, 5, 4]).unwrap(),
            false,
        );
        let output = gru.forward(&input);
        assert!(
            output.data().to_vec().iter().all(|v| v.is_finite()),
            "GRU should produce finite outputs for large inputs"
        );
    }
}