1use ferrotorch_core::grad_fns::activation::{relu, sigmoid, tanh};
32use ferrotorch_core::grad_fns::arithmetic::{add, mul, sub};
33use ferrotorch_core::grad_fns::shape::{cat, reshape};
34use ferrotorch_core::grad_fns::linalg::mm_differentiable as mm;
40use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float, Tensor, TensorStorage};
41
42use crate::init;
43use crate::module::Module;
44use crate::parameter::Parameter;
45
/// Result of [`LSTM::forward_with_state`]: `(output, (h_n, c_n))`, i.e. the
/// output sequence paired with the final hidden and cell states.
type LstmOutput<T> = (Tensor<T>, (Tensor<T>, Tensor<T>));
48
/// Learnable parameters of one LSTM layer.
///
/// The four gates are packed along the leading axis, so every tensor has a
/// first dimension of `4 * hidden_size`. The forward pass splits that axis
/// into four chunks used, in order, as the input, forget, cell-candidate and
/// output gate pre-activations.
#[derive(Debug, Clone)]
struct LSTMLayerParams<T: Float> {
    /// Input-to-hidden weights, `[4 * hidden_size, layer_input_size]`.
    weight_ih: Parameter<T>,
    /// Hidden-to-hidden weights, `[4 * hidden_size, hidden_size]`.
    weight_hh: Parameter<T>,
    /// Bias added to the input projection, `[4 * hidden_size]`.
    bias_ih: Parameter<T>,
    /// Bias added to the hidden projection, `[4 * hidden_size]`.
    bias_hh: Parameter<T>,
}
65
/// Multi-layer LSTM operating on batch-first input of shape
/// `[batch, seq_len, input_size]`.
#[derive(Debug)]
pub struct LSTM<T: Float> {
    /// Number of features in each input timestep.
    input_size: usize,
    /// Number of features in the hidden and cell states.
    hidden_size: usize,
    /// Number of stacked layers (always >= 1).
    num_layers: usize,
    /// Per-layer parameters; index 0 is the layer that reads the input.
    layers: Vec<LSTMLayerParams<T>>,
    /// Training-mode flag toggled by `train()` / `eval()`.
    training: bool,
}
98
99impl<T: Float> LSTM<T> {
100 pub fn new(input_size: usize, hidden_size: usize, num_layers: usize) -> FerrotorchResult<Self> {
113 if num_layers == 0 {
114 return Err(FerrotorchError::InvalidArgument {
115 message: "LSTM: num_layers must be >= 1".into(),
116 });
117 }
118 if hidden_size == 0 {
119 return Err(FerrotorchError::InvalidArgument {
120 message: "LSTM: hidden_size must be >= 1".into(),
121 });
122 }
123 if input_size == 0 {
124 return Err(FerrotorchError::InvalidArgument {
125 message: "LSTM: input_size must be >= 1".into(),
126 });
127 }
128
129 let k = 1.0 / (hidden_size as f64).sqrt();
130 let gate_size = 4 * hidden_size;
131
132 let mut layers = Vec::with_capacity(num_layers);
133
134 for layer_idx in 0..num_layers {
135 let layer_input_size = if layer_idx == 0 {
136 input_size
137 } else {
138 hidden_size
139 };
140
141 let mut weight_ih = Parameter::zeros(&[gate_size, layer_input_size])?;
142 let mut weight_hh = Parameter::zeros(&[gate_size, hidden_size])?;
143 let mut bias_ih = Parameter::zeros(&[gate_size])?;
144 let mut bias_hh = Parameter::zeros(&[gate_size])?;
145
146 init::uniform(&mut weight_ih, -k, k)?;
147 init::uniform(&mut weight_hh, -k, k)?;
148 init::zeros(&mut bias_ih)?;
149 init::zeros(&mut bias_hh)?;
150
151 layers.push(LSTMLayerParams {
152 weight_ih,
153 weight_hh,
154 bias_ih,
155 bias_hh,
156 });
157 }
158
159 Ok(Self {
160 input_size,
161 hidden_size,
162 num_layers,
163 layers,
164 training: true,
165 })
166 }
167
168 pub fn forward_with_state(
182 &self,
183 input: &Tensor<T>,
184 state: Option<(&Tensor<T>, &Tensor<T>)>,
185 ) -> FerrotorchResult<LstmOutput<T>> {
186 if input.ndim() != 3 {
188 return Err(FerrotorchError::InvalidArgument {
189 message: format!(
190 "LSTM: expected 3-D input [batch, seq_len, input_size], got shape {:?}",
191 input.shape()
192 ),
193 });
194 }
195
196 let batch = input.shape()[0];
197 let seq_len = input.shape()[1];
198
199 if input.shape()[2] != self.input_size {
200 return Err(FerrotorchError::ShapeMismatch {
201 message: format!(
202 "LSTM: input_size mismatch: expected {}, got {}",
203 self.input_size,
204 input.shape()[2]
205 ),
206 });
207 }
208
209 let (h_init, c_init) = match state {
211 Some((h0, c0)) => {
212 let expected_shape = [self.num_layers, batch, self.hidden_size];
214 if h0.shape() != expected_shape {
215 return Err(FerrotorchError::ShapeMismatch {
216 message: format!(
217 "LSTM: h_0 shape mismatch: expected {:?}, got {:?}",
218 expected_shape,
219 h0.shape()
220 ),
221 });
222 }
223 if c0.shape() != expected_shape {
224 return Err(FerrotorchError::ShapeMismatch {
225 message: format!(
226 "LSTM: c_0 shape mismatch: expected {:?}, got {:?}",
227 expected_shape,
228 c0.shape()
229 ),
230 });
231 }
232 (h0.clone(), c0.clone())
233 }
234 None => {
235 let init_shape = [self.num_layers, batch, self.hidden_size];
240 let h0 = ferrotorch_core::zeros::<T>(&init_shape)?.to(input.device())?;
241 let c0 = ferrotorch_core::zeros::<T>(&init_shape)?.to(input.device())?;
242 (h0, c0)
243 }
244 };
245
246 let mut timestep_inputs: Vec<Tensor<T>> = Vec::with_capacity(seq_len);
251 for t in 0..seq_len {
252 let slice = input.narrow(1, t, 1)?; timestep_inputs.push(slice.squeeze_t(1)?); }
255
256 let hs = self.hidden_size;
260 let mut layer_h: Vec<Tensor<T>> = Vec::with_capacity(self.num_layers);
261 let mut layer_c: Vec<Tensor<T>> = Vec::with_capacity(self.num_layers);
262 for l in 0..self.num_layers {
263 layer_h.push(h_init.narrow(0, l, 1)?.squeeze_t(0)?);
264 layer_c.push(c_init.narrow(0, l, 1)?.squeeze_t(0)?);
265 }
266
267 let mut layer_outputs: Vec<Tensor<T>> = timestep_inputs;
271 let mut final_h: Vec<Tensor<T>> = Vec::with_capacity(self.num_layers);
272 let mut final_c: Vec<Tensor<T>> = Vec::with_capacity(self.num_layers);
273
274 for (l, params) in self.layers.iter().enumerate() {
275 let mut h = layer_h[l].clone();
276 let mut c = layer_c[l].clone();
277 let mut next_layer_outputs: Vec<Tensor<T>> = Vec::with_capacity(seq_len);
278
279 let wih_t = transpose_2d(params.weight_ih.tensor())?.contiguous()?;
306 let whh_t = transpose_2d(params.weight_hh.tensor())?.contiguous()?;
307
308 let bias_ih_2d = broadcast_bias_to_batch(¶ms.bias_ih, batch)?;
330 let bias_hh_2d = broadcast_bias_to_batch(¶ms.bias_hh, batch)?;
331 let xw_all = batched_input_projection(&layer_outputs, &wih_t)?;
332
333 for (t, _x_t) in layer_outputs.iter().enumerate() {
334 let xw = xw_all.narrow(0, t * batch, batch)?; let hw = mm(&h, &whh_t)?; let gates = add(&add(&add(&xw, &bias_ih_2d)?, &hw)?, &bias_hh_2d)?;
341
342 let gate_chunks = gates.chunk(4, 1)?;
345 let i_pre = gate_chunks[0].clone();
346 let f_pre = gate_chunks[1].clone();
347 let g_pre = gate_chunks[2].clone();
348 let o_pre = gate_chunks[3].clone();
349
350 let i_gate = sigmoid(&i_pre)?;
352 let f_gate = sigmoid(&f_pre)?;
353 let g_gate = tanh(&g_pre)?;
354 let o_gate = sigmoid(&o_pre)?;
355
356 let fc = mul(&f_gate, &c)?;
358 let ig = mul(&i_gate, &g_gate)?;
359 let c_new = add(&fc, &ig)?;
360
361 let tanh_c = tanh(&c_new)?;
363 let h_new = mul(&o_gate, &tanh_c)?;
364
365 next_layer_outputs.push(h_new.clone());
366 h = h_new;
367 c = c_new;
368 }
369
370 final_h.push(h);
371 final_c.push(c);
372 layer_outputs = next_layer_outputs;
373 }
374
375 let output = if seq_len == 1 {
383 reshape(&layer_outputs[0], &[batch as isize, 1, hs as isize])?
385 } else {
386 let stacked = cat(&layer_outputs, 1)?;
388 reshape(&stacked, &[batch as isize, seq_len as isize, hs as isize])?
390 };
391
392 let h_n = if self.num_layers == 1 {
396 reshape(&final_h[0], &[1, batch as isize, hs as isize])?
397 } else {
398 let h_stacked = cat(&final_h, 0)?;
399 reshape(
400 &h_stacked,
401 &[self.num_layers as isize, batch as isize, hs as isize],
402 )?
403 };
404 let c_n = if self.num_layers == 1 {
405 reshape(&final_c[0], &[1, batch as isize, hs as isize])?
406 } else {
407 let c_stacked = cat(&final_c, 0)?;
408 reshape(
409 &c_stacked,
410 &[self.num_layers as isize, batch as isize, hs as isize],
411 )?
412 };
413
414 Ok((output, (h_n, c_n)))
415 }
416
417 #[inline]
419 pub fn input_size(&self) -> usize {
420 self.input_size
421 }
422
423 #[inline]
425 pub fn hidden_size(&self) -> usize {
426 self.hidden_size
427 }
428
429 #[inline]
431 pub fn num_layers(&self) -> usize {
432 self.num_layers
433 }
434}
435
436impl<T: Float> Module<T> for LSTM<T> {
441 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
448 let (output, _) = self.forward_with_state(input, None)?;
449 Ok(output)
450 }
451
452 fn parameters(&self) -> Vec<&Parameter<T>> {
453 let mut params = Vec::with_capacity(self.num_layers * 4);
454 for layer in &self.layers {
455 params.push(&layer.weight_ih);
456 params.push(&layer.weight_hh);
457 params.push(&layer.bias_ih);
458 params.push(&layer.bias_hh);
459 }
460 params
461 }
462
463 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
464 let mut params = Vec::with_capacity(self.num_layers * 4);
465 for layer in &mut self.layers {
466 params.push(&mut layer.weight_ih);
467 params.push(&mut layer.weight_hh);
468 params.push(&mut layer.bias_ih);
469 params.push(&mut layer.bias_hh);
470 }
471 params
472 }
473
474 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
475 let mut params = Vec::with_capacity(self.num_layers * 4);
476 for (i, layer) in self.layers.iter().enumerate() {
477 params.push((format!("layers.{i}.weight_ih"), &layer.weight_ih));
478 params.push((format!("layers.{i}.weight_hh"), &layer.weight_hh));
479 params.push((format!("layers.{i}.bias_ih"), &layer.bias_ih));
480 params.push((format!("layers.{i}.bias_hh"), &layer.bias_hh));
481 }
482 params
483 }
484
485 fn train(&mut self) {
486 self.training = true;
487 }
488
489 fn eval(&mut self) {
490 self.training = false;
491 }
492
493 fn is_training(&self) -> bool {
494 self.training
495 }
496}
497
498fn transpose_2d<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
510 ferrotorch_core::grad_fns::shape::transpose_2d(input)
511}
512
513fn batched_input_projection<T: Float>(
536 step_inputs: &[Tensor<T>],
537 wih_t: &Tensor<T>,
538) -> FerrotorchResult<Tensor<T>> {
539 if step_inputs.len() == 1 {
540 return mm(&step_inputs[0], wih_t);
542 }
543 let x_all = cat(step_inputs, 0)?;
546 mm(&x_all, wih_t) }
548
549fn broadcast_bias_to_batch<T: Float>(
555 bias: &Parameter<T>,
556 batch: usize,
557) -> FerrotorchResult<Tensor<T>> {
558 let n = bias.tensor().shape()[0];
559 let bias_2d = bias.tensor().unsqueeze_t(0)?; ferrotorch_core::grad_fns::shape::expand(&bias_2d, &[batch, n])
561}
562
/// Learnable parameters of one GRU layer.
///
/// The three gates are packed along the leading axis, so every tensor has a
/// first dimension of `3 * hidden_size`. The forward pass splits that axis
/// into three chunks used, in order, for the reset gate, the update gate and
/// the candidate state.
#[derive(Debug, Clone)]
struct GRULayerParams<T: Float> {
    /// Input-to-hidden weights, `[3 * hidden_size, layer_input_size]`.
    weight_ih: Parameter<T>,
    /// Hidden-to-hidden weights, `[3 * hidden_size, hidden_size]`.
    weight_hh: Parameter<T>,
    /// Bias added to the input projection, `[3 * hidden_size]`.
    bias_ih: Parameter<T>,
    /// Bias added to the hidden projection, `[3 * hidden_size]`.
    bias_hh: Parameter<T>,
}
579
/// Multi-layer GRU operating on batch-first input of shape
/// `[batch, seq_len, input_size]`.
#[derive(Debug)]
pub struct GRU<T: Float> {
    /// Number of features in each input timestep.
    input_size: usize,
    /// Number of features in the hidden state.
    hidden_size: usize,
    /// Number of stacked layers (always >= 1).
    num_layers: usize,
    /// Per-layer parameters; index 0 is the layer that reads the input.
    layers: Vec<GRULayerParams<T>>,
    /// Training-mode flag toggled by `train()` / `eval()`.
    training: bool,
}
610
611impl<T: Float> GRU<T> {
612 pub fn new(input_size: usize, hidden_size: usize) -> FerrotorchResult<Self> {
627 Self::with_num_layers(input_size, hidden_size, 1)
628 }
629
630 pub fn with_num_layers(
638 input_size: usize,
639 hidden_size: usize,
640 num_layers: usize,
641 ) -> FerrotorchResult<Self> {
642 if num_layers == 0 {
643 return Err(FerrotorchError::InvalidArgument {
644 message: "GRU: num_layers must be >= 1".into(),
645 });
646 }
647 if hidden_size == 0 {
648 return Err(FerrotorchError::InvalidArgument {
649 message: "GRU: hidden_size must be >= 1".into(),
650 });
651 }
652 if input_size == 0 {
653 return Err(FerrotorchError::InvalidArgument {
654 message: "GRU: input_size must be >= 1".into(),
655 });
656 }
657
658 let k = 1.0 / (hidden_size as f64).sqrt();
659 let gate_size = 3 * hidden_size;
660
661 let mut layers = Vec::with_capacity(num_layers);
662
663 for layer_idx in 0..num_layers {
664 let layer_input_size = if layer_idx == 0 {
665 input_size
666 } else {
667 hidden_size
668 };
669
670 let mut weight_ih = Parameter::zeros(&[gate_size, layer_input_size])?;
671 let mut weight_hh = Parameter::zeros(&[gate_size, hidden_size])?;
672 let mut bias_ih = Parameter::zeros(&[gate_size])?;
673 let mut bias_hh = Parameter::zeros(&[gate_size])?;
674
675 init::uniform(&mut weight_ih, -k, k)?;
676 init::uniform(&mut weight_hh, -k, k)?;
677 init::zeros(&mut bias_ih)?;
678 init::zeros(&mut bias_hh)?;
679
680 layers.push(GRULayerParams {
681 weight_ih,
682 weight_hh,
683 bias_ih,
684 bias_hh,
685 });
686 }
687
688 Ok(Self {
689 input_size,
690 hidden_size,
691 num_layers,
692 layers,
693 training: true,
694 })
695 }
696
697 pub fn forward(
711 &self,
712 input: &Tensor<T>,
713 h_0: Option<&Tensor<T>>,
714 ) -> FerrotorchResult<(Tensor<T>, Tensor<T>)> {
715 if input.ndim() != 3 {
717 return Err(FerrotorchError::InvalidArgument {
718 message: format!(
719 "GRU: expected 3-D input [batch, seq_len, input_size], got shape {:?}",
720 input.shape()
721 ),
722 });
723 }
724
725 let batch = input.shape()[0];
726 let seq_len = input.shape()[1];
727
728 if input.shape()[2] != self.input_size {
729 return Err(FerrotorchError::ShapeMismatch {
730 message: format!(
731 "GRU: input_size mismatch: expected {}, got {}",
732 self.input_size,
733 input.shape()[2]
734 ),
735 });
736 }
737
738 let h_init = match h_0 {
740 Some(h0) => {
741 let expected_shape = [self.num_layers, batch, self.hidden_size];
742 if h0.shape() != expected_shape {
743 return Err(FerrotorchError::ShapeMismatch {
744 message: format!(
745 "GRU: h_0 shape mismatch: expected {:?}, got {:?}",
746 expected_shape,
747 h0.shape()
748 ),
749 });
750 }
751 h0.clone()
752 }
753 None => ferrotorch_core::zeros::<T>(&[self.num_layers, batch, self.hidden_size])?,
754 };
755
756 let hs = self.hidden_size;
759 let mut timestep_inputs: Vec<Tensor<T>> = Vec::with_capacity(seq_len);
760 for t in 0..seq_len {
761 let slice = input.narrow(1, t, 1)?; timestep_inputs.push(slice.squeeze_t(1)?); }
764
765 let mut layer_h: Vec<Tensor<T>> = Vec::with_capacity(self.num_layers);
768 for l in 0..self.num_layers {
769 layer_h.push(h_init.narrow(0, l, 1)?.squeeze_t(0)?);
770 }
771
772 let mut layer_outputs: Vec<Tensor<T>> = timestep_inputs;
774 let mut final_h: Vec<Tensor<T>> = Vec::with_capacity(self.num_layers);
775
776 let is_f32 = std::mem::size_of::<T>() == 4;
777
778 for (l, params) in self.layers.iter().enumerate() {
779 let mut h = layer_h[l].clone();
780 let mut next_layer_outputs: Vec<Tensor<T>> = Vec::with_capacity(seq_len);
781
782 let wih_t = ferrotorch_core::grad_fns::shape::transpose_2d(params.weight_ih.tensor())?
789 .contiguous()?;
790 let whh_t = ferrotorch_core::grad_fns::shape::transpose_2d(params.weight_hh.tensor())?
791 .contiguous()?;
792
793 let xw_all = batched_input_projection(&layer_outputs, &wih_t)?;
802
803 let use_fused_gpu =
805 is_f32 && h.is_cuda() && ferrotorch_core::gpu_dispatch::gpu_backend().is_some();
806
807 for (t, _x_t) in layer_outputs.iter().enumerate() {
808 let xw = xw_all.narrow(0, t * batch, batch)?; let hw = mm(&h, &whh_t)?; if use_fused_gpu {
814 let xw_c = xw.contiguous()?;
824 let backend = ferrotorch_core::gpu_dispatch::gpu_backend()
825 .ok_or(FerrotorchError::DeviceUnavailable)?;
826 let (hy_handle, _workspace) = backend.fused_gru_cell_f32(
827 xw_c.gpu_handle()?,
828 hw.gpu_handle()?,
829 params.bias_ih.tensor().gpu_handle()?,
830 params.bias_hh.tensor().gpu_handle()?,
831 h.gpu_handle()?,
832 hs,
833 )?;
834 let h_new = Tensor::from_storage(
835 TensorStorage::gpu(hy_handle),
836 vec![batch, hs],
837 false,
838 )?;
839 next_layer_outputs.push(h_new.clone());
840 h = h_new;
841 } else {
842 let bias_ih_2d = broadcast_bias_to_batch(¶ms.bias_ih, batch)?;
848 let bias_hh_2d = broadcast_bias_to_batch(¶ms.bias_hh, batch)?;
849
850 let xw_b = add(&xw, &bias_ih_2d)?;
851 let hw_b = add(&hw, &bias_hh_2d)?;
852
853 let xw_chunks = xw_b.chunk(3, 1)?;
855 let hw_chunks = hw_b.chunk(3, 1)?;
856 let rx = xw_chunks[0].clone();
857 let zx = xw_chunks[1].clone();
858 let nx = xw_chunks[2].clone();
859 let rh = hw_chunks[0].clone();
860 let zh = hw_chunks[1].clone();
861 let nh = hw_chunks[2].clone();
862
863 let r_gate = sigmoid(&add(&rx, &rh)?)?;
864 let z_gate = sigmoid(&add(&zx, &zh)?)?;
865 let r_nh = mul(&r_gate, &nh)?;
866 let n_gate = tanh(&add(&nx, &r_nh)?)?;
867 let h_minus_n = sub(&h, &n_gate)?;
868 let z_h_minus_n = mul(&z_gate, &h_minus_n)?;
869 let h_new = add(&n_gate, &z_h_minus_n)?;
870
871 next_layer_outputs.push(h_new.clone());
872 h = h_new;
873 }
874 }
875
876 final_h.push(h);
877 layer_outputs = next_layer_outputs;
878 }
879
880 let output = if seq_len == 1 {
885 reshape(&layer_outputs[0], &[batch as isize, 1, hs as isize])?
886 } else {
887 let stacked = cat(&layer_outputs, 1)?;
888 reshape(&stacked, &[batch as isize, seq_len as isize, hs as isize])?
889 };
890
891 let h_n = if self.num_layers == 1 {
895 reshape(&final_h[0], &[1, batch as isize, hs as isize])?
896 } else {
897 let h_stacked = cat(&final_h, 0)?;
898 reshape(
899 &h_stacked,
900 &[self.num_layers as isize, batch as isize, hs as isize],
901 )?
902 };
903
904 Ok((output, h_n))
905 }
906
907 #[inline]
909 pub fn input_size(&self) -> usize {
910 self.input_size
911 }
912
913 #[inline]
915 pub fn hidden_size(&self) -> usize {
916 self.hidden_size
917 }
918
919 #[inline]
921 pub fn num_layers(&self) -> usize {
922 self.num_layers
923 }
924}
925
926impl<T: Float> Module<T> for GRU<T> {
931 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
938 let (output, _) = GRU::forward(self, input, None)?;
939 Ok(output)
940 }
941
942 fn parameters(&self) -> Vec<&Parameter<T>> {
943 let mut params = Vec::with_capacity(self.num_layers * 4);
944 for layer in &self.layers {
945 params.push(&layer.weight_ih);
946 params.push(&layer.weight_hh);
947 params.push(&layer.bias_ih);
948 params.push(&layer.bias_hh);
949 }
950 params
951 }
952
953 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
954 let mut params = Vec::with_capacity(self.num_layers * 4);
955 for layer in &mut self.layers {
956 params.push(&mut layer.weight_ih);
957 params.push(&mut layer.weight_hh);
958 params.push(&mut layer.bias_ih);
959 params.push(&mut layer.bias_hh);
960 }
961 params
962 }
963
964 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
965 let mut params = Vec::with_capacity(self.num_layers * 4);
966 for (i, layer) in self.layers.iter().enumerate() {
967 params.push((format!("layers.{i}.weight_ih"), &layer.weight_ih));
968 params.push((format!("layers.{i}.weight_hh"), &layer.weight_hh));
969 params.push((format!("layers.{i}.bias_ih"), &layer.bias_ih));
970 params.push((format!("layers.{i}.bias_hh"), &layer.bias_hh));
971 }
972 params
973 }
974
975 fn train(&mut self) {
976 self.training = true;
977 }
978
979 fn eval(&mut self) {
980 self.training = false;
981 }
982
983 fn is_training(&self) -> bool {
984 self.training
985 }
986}
987
/// Activation applied to the pre-activation of an Elman RNN step.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RNNNonlinearity {
    /// Apply `tanh`.
    Tanh,
    /// Apply `relu`.
    ReLU,
}
1000
/// Single-step Elman RNN cell: one timestep over a `[batch, input_size]` input.
#[derive(Debug)]
pub struct RNNCell<T: Float> {
    /// Number of input features.
    input_size: usize,
    /// Number of hidden-state features.
    hidden_size: usize,
    /// Activation applied to the summed projections.
    nonlinearity: RNNNonlinearity,
    /// Input-to-hidden weights, `[hidden_size, input_size]`.
    weight_ih: Parameter<T>,
    /// Hidden-to-hidden weights, `[hidden_size, hidden_size]`.
    weight_hh: Parameter<T>,
    /// Bias added to the input projection, `[hidden_size]`.
    bias_ih: Parameter<T>,
    /// Bias added to the hidden projection, `[hidden_size]`.
    bias_hh: Parameter<T>,
    /// Training-mode flag toggled by `train()` / `eval()`.
    training: bool,
}
1017
1018impl<T: Float> RNNCell<T> {
1019 pub fn new(input_size: usize, hidden_size: usize) -> FerrotorchResult<Self> {
1029 Self::with_nonlinearity(input_size, hidden_size, RNNNonlinearity::Tanh)
1030 }
1031
1032 pub fn with_nonlinearity(
1034 input_size: usize,
1035 hidden_size: usize,
1036 nonlinearity: RNNNonlinearity,
1037 ) -> FerrotorchResult<Self> {
1038 if hidden_size == 0 {
1039 return Err(FerrotorchError::InvalidArgument {
1040 message: "RNNCell: hidden_size must be >= 1".into(),
1041 });
1042 }
1043 if input_size == 0 {
1044 return Err(FerrotorchError::InvalidArgument {
1045 message: "RNNCell: input_size must be >= 1".into(),
1046 });
1047 }
1048
1049 let k = 1.0 / (hidden_size as f64).sqrt();
1050
1051 let mut weight_ih = Parameter::zeros(&[hidden_size, input_size])?;
1052 let mut weight_hh = Parameter::zeros(&[hidden_size, hidden_size])?;
1053 let mut bias_ih = Parameter::zeros(&[hidden_size])?;
1054 let mut bias_hh = Parameter::zeros(&[hidden_size])?;
1055
1056 init::uniform(&mut weight_ih, -k, k)?;
1057 init::uniform(&mut weight_hh, -k, k)?;
1058 init::zeros(&mut bias_ih)?;
1059 init::zeros(&mut bias_hh)?;
1060
1061 Ok(Self {
1062 input_size,
1063 hidden_size,
1064 nonlinearity,
1065 weight_ih,
1066 weight_hh,
1067 bias_ih,
1068 bias_hh,
1069 training: true,
1070 })
1071 }
1072
1073 pub fn forward_cell(
1085 &self,
1086 input: &Tensor<T>,
1087 h: Option<&Tensor<T>>,
1088 ) -> FerrotorchResult<Tensor<T>> {
1089 if input.ndim() != 2 {
1090 return Err(FerrotorchError::InvalidArgument {
1091 message: format!(
1092 "RNNCell: expected 2-D input [batch, input_size], got shape {:?}",
1093 input.shape()
1094 ),
1095 });
1096 }
1097 let batch = input.shape()[0];
1098 if input.shape()[1] != self.input_size {
1099 return Err(FerrotorchError::ShapeMismatch {
1100 message: format!(
1101 "RNNCell: input_size mismatch: expected {}, got {}",
1102 self.input_size,
1103 input.shape()[1]
1104 ),
1105 });
1106 }
1107
1108 let h_state = match h {
1109 Some(h0) => {
1110 if h0.shape() != [batch, self.hidden_size] {
1111 return Err(FerrotorchError::ShapeMismatch {
1112 message: format!(
1113 "RNNCell: h shape mismatch: expected {:?}, got {:?}",
1114 [batch, self.hidden_size],
1115 h0.shape()
1116 ),
1117 });
1118 }
1119 h0.clone()
1120 }
1121 None => ferrotorch_core::zeros::<T>(&[batch, self.hidden_size])?,
1122 };
1123
1124 let wih_t = transpose_2d(self.weight_ih.tensor())?;
1125 let whh_t = transpose_2d(self.weight_hh.tensor())?;
1126
1127 let xw = mm(input, &wih_t)?; let hw = mm(&h_state, &whh_t)?; let bias_ih_2d = broadcast_bias_to_batch(&self.bias_ih, batch)?;
1131 let bias_hh_2d = broadcast_bias_to_batch(&self.bias_hh, batch)?;
1132
1133 let pre_act = add(&add(&add(&xw, &bias_ih_2d)?, &hw)?, &bias_hh_2d)?;
1134
1135 match self.nonlinearity {
1136 RNNNonlinearity::Tanh => tanh(&pre_act),
1137 RNNNonlinearity::ReLU => relu(&pre_act),
1138 }
1139 }
1140
1141 #[inline]
1143 pub fn input_size(&self) -> usize {
1144 self.input_size
1145 }
1146
1147 #[inline]
1149 pub fn hidden_size(&self) -> usize {
1150 self.hidden_size
1151 }
1152
1153 #[inline]
1155 pub fn nonlinearity(&self) -> RNNNonlinearity {
1156 self.nonlinearity
1157 }
1158}
1159
1160impl<T: Float> Module<T> for RNNCell<T> {
1161 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
1163 self.forward_cell(input, None)
1164 }
1165
1166 fn parameters(&self) -> Vec<&Parameter<T>> {
1167 vec![
1168 &self.weight_ih,
1169 &self.weight_hh,
1170 &self.bias_ih,
1171 &self.bias_hh,
1172 ]
1173 }
1174
1175 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
1176 vec![
1177 &mut self.weight_ih,
1178 &mut self.weight_hh,
1179 &mut self.bias_ih,
1180 &mut self.bias_hh,
1181 ]
1182 }
1183
1184 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
1185 vec![
1186 ("weight_ih".into(), &self.weight_ih),
1187 ("weight_hh".into(), &self.weight_hh),
1188 ("bias_ih".into(), &self.bias_ih),
1189 ("bias_hh".into(), &self.bias_hh),
1190 ]
1191 }
1192
1193 fn train(&mut self) {
1194 self.training = true;
1195 }
1196
1197 fn eval(&mut self) {
1198 self.training = false;
1199 }
1200
1201 fn is_training(&self) -> bool {
1202 self.training
1203 }
1204}
1205
/// Single-step LSTM cell: one timestep over a `[batch, input_size]` input.
///
/// The four gates are packed along the leading axis of every parameter.
#[derive(Debug)]
pub struct LSTMCell<T: Float> {
    /// Number of input features.
    input_size: usize,
    /// Number of hidden/cell-state features.
    hidden_size: usize,
    /// Input-to-hidden weights, `[4 * hidden_size, input_size]`.
    weight_ih: Parameter<T>,
    /// Hidden-to-hidden weights, `[4 * hidden_size, hidden_size]`.
    weight_hh: Parameter<T>,
    /// Bias added to the input projection, `[4 * hidden_size]`.
    bias_ih: Parameter<T>,
    /// Bias added to the hidden projection, `[4 * hidden_size]`.
    bias_hh: Parameter<T>,
    /// Training-mode flag toggled by `train()` / `eval()`.
    training: bool,
}
1234
1235impl<T: Float> LSTMCell<T> {
1236 pub fn new(input_size: usize, hidden_size: usize) -> FerrotorchResult<Self> {
1243 if hidden_size == 0 {
1244 return Err(FerrotorchError::InvalidArgument {
1245 message: "LSTMCell: hidden_size must be >= 1".into(),
1246 });
1247 }
1248 if input_size == 0 {
1249 return Err(FerrotorchError::InvalidArgument {
1250 message: "LSTMCell: input_size must be >= 1".into(),
1251 });
1252 }
1253
1254 let k = 1.0 / (hidden_size as f64).sqrt();
1255 let gate_size = 4 * hidden_size;
1256
1257 let mut weight_ih = Parameter::zeros(&[gate_size, input_size])?;
1258 let mut weight_hh = Parameter::zeros(&[gate_size, hidden_size])?;
1259 let mut bias_ih = Parameter::zeros(&[gate_size])?;
1260 let mut bias_hh = Parameter::zeros(&[gate_size])?;
1261
1262 init::uniform(&mut weight_ih, -k, k)?;
1263 init::uniform(&mut weight_hh, -k, k)?;
1264 init::zeros(&mut bias_ih)?;
1265 init::zeros(&mut bias_hh)?;
1266
1267 Ok(Self {
1268 input_size,
1269 hidden_size,
1270 weight_ih,
1271 weight_hh,
1272 bias_ih,
1273 bias_hh,
1274 training: true,
1275 })
1276 }
1277
1278 pub fn forward_cell(
1290 &self,
1291 input: &Tensor<T>,
1292 state: Option<(&Tensor<T>, &Tensor<T>)>,
1293 ) -> FerrotorchResult<(Tensor<T>, Tensor<T>)> {
1294 if input.ndim() != 2 {
1295 return Err(FerrotorchError::InvalidArgument {
1296 message: format!(
1297 "LSTMCell: expected 2-D input [batch, input_size], got shape {:?}",
1298 input.shape()
1299 ),
1300 });
1301 }
1302 let batch = input.shape()[0];
1303 if input.shape()[1] != self.input_size {
1304 return Err(FerrotorchError::ShapeMismatch {
1305 message: format!(
1306 "LSTMCell: input_size mismatch: expected {}, got {}",
1307 self.input_size,
1308 input.shape()[1]
1309 ),
1310 });
1311 }
1312
1313 let expected_h_shape = [batch, self.hidden_size];
1314
1315 let (h_state, c_state) = match state {
1316 Some((h0, c0)) => {
1317 if h0.shape() != expected_h_shape {
1318 return Err(FerrotorchError::ShapeMismatch {
1319 message: format!(
1320 "LSTMCell: h shape mismatch: expected {:?}, got {:?}",
1321 expected_h_shape,
1322 h0.shape()
1323 ),
1324 });
1325 }
1326 if c0.shape() != expected_h_shape {
1327 return Err(FerrotorchError::ShapeMismatch {
1328 message: format!(
1329 "LSTMCell: c shape mismatch: expected {:?}, got {:?}",
1330 expected_h_shape,
1331 c0.shape()
1332 ),
1333 });
1334 }
1335 (h0.clone(), c0.clone())
1336 }
1337 None => {
1338 let h0 = ferrotorch_core::zeros::<T>(&[batch, self.hidden_size])?;
1339 let c0 = ferrotorch_core::zeros::<T>(&[batch, self.hidden_size])?;
1340 (h0, c0)
1341 }
1342 };
1343
1344 let wih_t = transpose_2d(self.weight_ih.tensor())?;
1345 let whh_t = transpose_2d(self.weight_hh.tensor())?;
1346
1347 let xw = mm(input, &wih_t)?; let hw = mm(&h_state, &whh_t)?; let bias_ih_2d = broadcast_bias_to_batch(&self.bias_ih, batch)?;
1351 let bias_hh_2d = broadcast_bias_to_batch(&self.bias_hh, batch)?;
1352
1353 let gates = add(&add(&add(&xw, &bias_ih_2d)?, &hw)?, &bias_hh_2d)?;
1354
1355 let gate_chunks = gates.chunk(4, 1)?;
1357 let i_gate = sigmoid(&gate_chunks[0])?;
1358 let f_gate = sigmoid(&gate_chunks[1])?;
1359 let g_gate = tanh(&gate_chunks[2])?;
1360 let o_gate = sigmoid(&gate_chunks[3])?;
1361
1362 let c_new = add(&mul(&f_gate, &c_state)?, &mul(&i_gate, &g_gate)?)?;
1364
1365 let h_new = mul(&o_gate, &tanh(&c_new)?)?;
1367
1368 Ok((h_new, c_new))
1369 }
1370
1371 #[inline]
1373 pub fn input_size(&self) -> usize {
1374 self.input_size
1375 }
1376
1377 #[inline]
1379 pub fn hidden_size(&self) -> usize {
1380 self.hidden_size
1381 }
1382}
1383
1384impl<T: Float> Module<T> for LSTMCell<T> {
1385 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
1387 let (h, _c) = self.forward_cell(input, None)?;
1388 Ok(h)
1389 }
1390
1391 fn parameters(&self) -> Vec<&Parameter<T>> {
1392 vec![
1393 &self.weight_ih,
1394 &self.weight_hh,
1395 &self.bias_ih,
1396 &self.bias_hh,
1397 ]
1398 }
1399
1400 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
1401 vec![
1402 &mut self.weight_ih,
1403 &mut self.weight_hh,
1404 &mut self.bias_ih,
1405 &mut self.bias_hh,
1406 ]
1407 }
1408
1409 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
1410 vec![
1411 ("weight_ih".into(), &self.weight_ih),
1412 ("weight_hh".into(), &self.weight_hh),
1413 ("bias_ih".into(), &self.bias_ih),
1414 ("bias_hh".into(), &self.bias_hh),
1415 ]
1416 }
1417
1418 fn train(&mut self) {
1419 self.training = true;
1420 }
1421
1422 fn eval(&mut self) {
1423 self.training = false;
1424 }
1425
1426 fn is_training(&self) -> bool {
1427 self.training
1428 }
1429}
1430
/// Single-step GRU cell: one timestep over a `[batch, input_size]` input.
///
/// The three gates are packed along the leading axis of every parameter.
#[derive(Debug)]
pub struct GRUCell<T: Float> {
    /// Number of input features.
    input_size: usize,
    /// Number of hidden-state features.
    hidden_size: usize,
    /// Input-to-hidden weights, `[3 * hidden_size, input_size]`.
    weight_ih: Parameter<T>,
    /// Hidden-to-hidden weights, `[3 * hidden_size, hidden_size]`.
    weight_hh: Parameter<T>,
    /// Bias added to the input projection, `[3 * hidden_size]`.
    bias_ih: Parameter<T>,
    /// Bias added to the hidden projection, `[3 * hidden_size]`.
    bias_hh: Parameter<T>,
    /// Training-mode flag toggled by `train()` / `eval()`.
    training: bool,
}
1456
1457impl<T: Float> GRUCell<T> {
1458 pub fn new(input_size: usize, hidden_size: usize) -> FerrotorchResult<Self> {
1465 if hidden_size == 0 {
1466 return Err(FerrotorchError::InvalidArgument {
1467 message: "GRUCell: hidden_size must be >= 1".into(),
1468 });
1469 }
1470 if input_size == 0 {
1471 return Err(FerrotorchError::InvalidArgument {
1472 message: "GRUCell: input_size must be >= 1".into(),
1473 });
1474 }
1475
1476 let k = 1.0 / (hidden_size as f64).sqrt();
1477 let gate_size = 3 * hidden_size;
1478
1479 let mut weight_ih = Parameter::zeros(&[gate_size, input_size])?;
1480 let mut weight_hh = Parameter::zeros(&[gate_size, hidden_size])?;
1481 let mut bias_ih = Parameter::zeros(&[gate_size])?;
1482 let mut bias_hh = Parameter::zeros(&[gate_size])?;
1483
1484 init::uniform(&mut weight_ih, -k, k)?;
1485 init::uniform(&mut weight_hh, -k, k)?;
1486 init::zeros(&mut bias_ih)?;
1487 init::zeros(&mut bias_hh)?;
1488
1489 Ok(Self {
1490 input_size,
1491 hidden_size,
1492 weight_ih,
1493 weight_hh,
1494 bias_ih,
1495 bias_hh,
1496 training: true,
1497 })
1498 }
1499
1500 pub fn forward_cell(
1512 &self,
1513 input: &Tensor<T>,
1514 h: Option<&Tensor<T>>,
1515 ) -> FerrotorchResult<Tensor<T>> {
1516 if input.ndim() != 2 {
1517 return Err(FerrotorchError::InvalidArgument {
1518 message: format!(
1519 "GRUCell: expected 2-D input [batch, input_size], got shape {:?}",
1520 input.shape()
1521 ),
1522 });
1523 }
1524 let batch = input.shape()[0];
1525 if input.shape()[1] != self.input_size {
1526 return Err(FerrotorchError::ShapeMismatch {
1527 message: format!(
1528 "GRUCell: input_size mismatch: expected {}, got {}",
1529 self.input_size,
1530 input.shape()[1]
1531 ),
1532 });
1533 }
1534
1535 let hs = self.hidden_size;
1536
1537 let h_state = match h {
1538 Some(h0) => {
1539 if h0.shape() != [batch, hs] {
1540 return Err(FerrotorchError::ShapeMismatch {
1541 message: format!(
1542 "GRUCell: h shape mismatch: expected {:?}, got {:?}",
1543 [batch, hs],
1544 h0.shape()
1545 ),
1546 });
1547 }
1548 h0.clone()
1549 }
1550 None => ferrotorch_core::zeros::<T>(&[batch, hs])?,
1551 };
1552
1553 let wih_t = transpose_2d(self.weight_ih.tensor())?;
1554 let whh_t = transpose_2d(self.weight_hh.tensor())?;
1555
1556 let xw = mm(input, &wih_t)?; let hw = mm(&h_state, &whh_t)?; let bias_ih_2d = broadcast_bias_to_batch(&self.bias_ih, batch)?;
1560 let bias_hh_2d = broadcast_bias_to_batch(&self.bias_hh, batch)?;
1561
1562 let xw_b = add(&xw, &bias_ih_2d)?;
1563 let hw_b = add(&hw, &bias_hh_2d)?;
1564
1565 let xw_chunks = xw_b.chunk(3, 1)?;
1568 let hw_chunks = hw_b.chunk(3, 1)?;
1569 let rx = xw_chunks[0].clone();
1570 let zx = xw_chunks[1].clone();
1571 let nx = xw_chunks[2].clone();
1572 let rh = hw_chunks[0].clone();
1573 let zh = hw_chunks[1].clone();
1574 let nh = hw_chunks[2].clone();
1575
1576 let r_gate = sigmoid(&add(&rx, &rh)?)?;
1578 let z_gate = sigmoid(&add(&zx, &zh)?)?;
1579
1580 let r_nh = mul(&r_gate, &nh)?;
1582 let n_gate = tanh(&add(&nx, &r_nh)?)?;
1583
1584 let h_minus_n = sub(&h_state, &n_gate)?;
1586 let z_h_minus_n = mul(&z_gate, &h_minus_n)?;
1587 add(&n_gate, &z_h_minus_n)
1588 }
1589
1590 #[inline]
1592 pub fn input_size(&self) -> usize {
1593 self.input_size
1594 }
1595
1596 #[inline]
1598 pub fn hidden_size(&self) -> usize {
1599 self.hidden_size
1600 }
1601}
1602
1603impl<T: Float> Module<T> for GRUCell<T> {
1604 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
1606 self.forward_cell(input, None)
1607 }
1608
1609 fn parameters(&self) -> Vec<&Parameter<T>> {
1610 vec![
1611 &self.weight_ih,
1612 &self.weight_hh,
1613 &self.bias_ih,
1614 &self.bias_hh,
1615 ]
1616 }
1617
1618 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
1619 vec![
1620 &mut self.weight_ih,
1621 &mut self.weight_hh,
1622 &mut self.bias_ih,
1623 &mut self.bias_hh,
1624 ]
1625 }
1626
1627 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
1628 vec![
1629 ("weight_ih".into(), &self.weight_ih),
1630 ("weight_hh".into(), &self.weight_hh),
1631 ("bias_ih".into(), &self.bias_ih),
1632 ("bias_hh".into(), &self.bias_hh),
1633 ]
1634 }
1635
1636 fn train(&mut self) {
1637 self.training = true;
1638 }
1639
1640 fn eval(&mut self) {
1641 self.training = false;
1642 }
1643
1644 fn is_training(&self) -> bool {
1645 self.training
1646 }
1647}
1648
/// Return type of `RNN::forward_with_state`: a pair of tensors.
// NOTE(review): presumably `(output, h_n)` by analogy with `GRU::forward` - confirm.
type RnnOutput<T> = (Tensor<T>, Tensor<T>);
1655
/// Learnable parameters of one Elman RNN layer.
#[derive(Debug, Clone)]
struct RNNLayerParams<T: Float> {
    /// Input-to-hidden weights, `[hidden_size, layer_input_size]`.
    weight_ih: Parameter<T>,
    /// Hidden-to-hidden weights, `[hidden_size, hidden_size]`.
    weight_hh: Parameter<T>,
    /// Bias paired with the input projection, `[hidden_size]`.
    bias_ih: Parameter<T>,
    /// Bias paired with the hidden projection, `[hidden_size]`.
    bias_hh: Parameter<T>,
}
1668
/// Multi-layer Elman RNN with a configurable nonlinearity.
#[derive(Debug)]
pub struct RNN<T: Float> {
    /// Number of features in each input timestep.
    input_size: usize,
    /// Number of features in the hidden state.
    hidden_size: usize,
    /// Number of stacked layers (always >= 1).
    num_layers: usize,
    /// Activation selected at construction time.
    nonlinearity: RNNNonlinearity,
    /// Per-layer parameters; index 0 is the layer sized for the input.
    layers: Vec<RNNLayerParams<T>>,
    /// Training-mode flag; starts out `true`.
    training: bool,
}
1689
1690impl<T: Float> RNN<T> {
1691 pub fn new(input_size: usize, hidden_size: usize) -> FerrotorchResult<Self> {
1693 Self::with_options(input_size, hidden_size, 1, RNNNonlinearity::Tanh)
1694 }
1695
1696 pub fn with_options(
1705 input_size: usize,
1706 hidden_size: usize,
1707 num_layers: usize,
1708 nonlinearity: RNNNonlinearity,
1709 ) -> FerrotorchResult<Self> {
1710 if num_layers == 0 {
1711 return Err(FerrotorchError::InvalidArgument {
1712 message: "RNN: num_layers must be >= 1".into(),
1713 });
1714 }
1715 if hidden_size == 0 {
1716 return Err(FerrotorchError::InvalidArgument {
1717 message: "RNN: hidden_size must be >= 1".into(),
1718 });
1719 }
1720 if input_size == 0 {
1721 return Err(FerrotorchError::InvalidArgument {
1722 message: "RNN: input_size must be >= 1".into(),
1723 });
1724 }
1725
1726 let k = 1.0 / (hidden_size as f64).sqrt();
1727
1728 let mut layers = Vec::with_capacity(num_layers);
1729
1730 for layer_idx in 0..num_layers {
1731 let layer_input_size = if layer_idx == 0 {
1732 input_size
1733 } else {
1734 hidden_size
1735 };
1736
1737 let mut weight_ih = Parameter::zeros(&[hidden_size, layer_input_size])?;
1738 let mut weight_hh = Parameter::zeros(&[hidden_size, hidden_size])?;
1739 let mut bias_ih = Parameter::zeros(&[hidden_size])?;
1740 let mut bias_hh = Parameter::zeros(&[hidden_size])?;
1741
1742 init::uniform(&mut weight_ih, -k, k)?;
1743 init::uniform(&mut weight_hh, -k, k)?;
1744 init::zeros(&mut bias_ih)?;
1745 init::zeros(&mut bias_hh)?;
1746
1747 layers.push(RNNLayerParams {
1748 weight_ih,
1749 weight_hh,
1750 bias_ih,
1751 bias_hh,
1752 });
1753 }
1754
1755 Ok(Self {
1756 input_size,
1757 hidden_size,
1758 num_layers,
1759 nonlinearity,
1760 layers,
1761 training: true,
1762 })
1763 }
1764
1765 pub fn forward_with_state(
1779 &self,
1780 input: &Tensor<T>,
1781 h_0: Option<&Tensor<T>>,
1782 ) -> FerrotorchResult<RnnOutput<T>> {
1783 if input.ndim() != 3 {
1784 return Err(FerrotorchError::InvalidArgument {
1785 message: format!(
1786 "RNN: expected 3-D input [batch, seq_len, input_size], got shape {:?}",
1787 input.shape()
1788 ),
1789 });
1790 }
1791
1792 let batch = input.shape()[0];
1793 let seq_len = input.shape()[1];
1794 let hs = self.hidden_size;
1795
1796 if input.shape()[2] != self.input_size {
1797 return Err(FerrotorchError::ShapeMismatch {
1798 message: format!(
1799 "RNN: input_size mismatch: expected {}, got {}",
1800 self.input_size,
1801 input.shape()[2]
1802 ),
1803 });
1804 }
1805
1806 let h_init = match h_0 {
1808 Some(h0) => {
1809 let expected_shape = [self.num_layers, batch, hs];
1810 if h0.shape() != expected_shape {
1811 return Err(FerrotorchError::ShapeMismatch {
1812 message: format!(
1813 "RNN: h_0 shape mismatch: expected {:?}, got {:?}",
1814 expected_shape,
1815 h0.shape()
1816 ),
1817 });
1818 }
1819 h0.clone()
1820 }
1821 None => ferrotorch_core::zeros::<T>(&[self.num_layers, batch, hs])?,
1822 };
1823
1824 let mut timestep_inputs: Vec<Tensor<T>> = Vec::with_capacity(seq_len);
1827 for t in 0..seq_len {
1828 let slice = input.narrow(1, t, 1)?; timestep_inputs.push(slice.squeeze_t(1)?); }
1831
1832 let mut layer_h: Vec<Tensor<T>> = Vec::with_capacity(self.num_layers);
1834 for l in 0..self.num_layers {
1835 layer_h.push(h_init.narrow(0, l, 1)?.squeeze_t(0)?);
1836 }
1837
1838 let mut layer_outputs: Vec<Tensor<T>> = timestep_inputs;
1840 let mut final_h: Vec<Tensor<T>> = Vec::with_capacity(self.num_layers);
1841
1842 for (l, params) in self.layers.iter().enumerate() {
1843 let mut h = layer_h[l].clone();
1844 let mut next_layer_outputs: Vec<Tensor<T>> = Vec::with_capacity(seq_len);
1845
1846 let wih_t = transpose_2d(params.weight_ih.tensor())?.contiguous()?;
1851 let whh_t = transpose_2d(params.weight_hh.tensor())?.contiguous()?;
1852
1853 let bias_ih_2d = broadcast_bias_to_batch(¶ms.bias_ih, batch)?;
1860 let bias_hh_2d = broadcast_bias_to_batch(¶ms.bias_hh, batch)?;
1861 let xw_all = batched_input_projection(&layer_outputs, &wih_t)?;
1862
1863 for (t, _x_t) in layer_outputs.iter().enumerate() {
1864 let xw = xw_all.narrow(0, t * batch, batch)?; let hw = mm(&h, &whh_t)?; let pre_act = add(&add(&add(&xw, &bias_ih_2d)?, &hw)?, &bias_hh_2d)?;
1868
1869 let h_new = match self.nonlinearity {
1870 RNNNonlinearity::Tanh => tanh(&pre_act)?,
1871 RNNNonlinearity::ReLU => relu(&pre_act)?,
1872 };
1873
1874 next_layer_outputs.push(h_new.clone());
1875 h = h_new;
1876 }
1877
1878 final_h.push(h);
1879 layer_outputs = next_layer_outputs;
1880 }
1881
1882 let output = if seq_len == 1 {
1886 reshape(&layer_outputs[0], &[batch as isize, 1, hs as isize])?
1887 } else {
1888 let stacked = cat(&layer_outputs, 1)?;
1889 reshape(&stacked, &[batch as isize, seq_len as isize, hs as isize])?
1890 };
1891
1892 let h_n = if self.num_layers == 1 {
1894 reshape(&final_h[0], &[1, batch as isize, hs as isize])?
1895 } else {
1896 let h_stacked = cat(&final_h, 0)?;
1897 reshape(
1898 &h_stacked,
1899 &[self.num_layers as isize, batch as isize, hs as isize],
1900 )?
1901 };
1902
1903 Ok((output, h_n))
1904 }
1905
    /// Returns the feature size expected on the last input dimension.
    #[inline]
    pub fn input_size(&self) -> usize {
        self.input_size
    }
1911
    /// Returns the hidden-state width shared by all layers.
    #[inline]
    pub fn hidden_size(&self) -> usize {
        self.hidden_size
    }
1917
    /// Returns the number of stacked layers.
    #[inline]
    pub fn num_layers(&self) -> usize {
        self.num_layers
    }
1923
    /// Returns the nonlinearity applied at every time step.
    #[inline]
    pub fn nonlinearity(&self) -> RNNNonlinearity {
        self.nonlinearity
    }
1929}
1930
1931impl<T: Float> Module<T> for RNN<T> {
1932 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
1934 let (output, _) = self.forward_with_state(input, None)?;
1935 Ok(output)
1936 }
1937
1938 fn parameters(&self) -> Vec<&Parameter<T>> {
1939 let mut params = Vec::with_capacity(self.num_layers * 4);
1940 for layer in &self.layers {
1941 params.push(&layer.weight_ih);
1942 params.push(&layer.weight_hh);
1943 params.push(&layer.bias_ih);
1944 params.push(&layer.bias_hh);
1945 }
1946 params
1947 }
1948
1949 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
1950 let mut params = Vec::with_capacity(self.num_layers * 4);
1951 for layer in &mut self.layers {
1952 params.push(&mut layer.weight_ih);
1953 params.push(&mut layer.weight_hh);
1954 params.push(&mut layer.bias_ih);
1955 params.push(&mut layer.bias_hh);
1956 }
1957 params
1958 }
1959
1960 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
1961 let mut params = Vec::with_capacity(self.num_layers * 4);
1962 for (i, layer) in self.layers.iter().enumerate() {
1963 params.push((format!("layers.{i}.weight_ih"), &layer.weight_ih));
1964 params.push((format!("layers.{i}.weight_hh"), &layer.weight_hh));
1965 params.push((format!("layers.{i}.bias_ih"), &layer.bias_ih));
1966 params.push((format!("layers.{i}.bias_hh"), &layer.bias_hh));
1967 }
1968 params
1969 }
1970
1971 fn train(&mut self) {
1972 self.training = true;
1973 }
1974
1975 fn eval(&mut self) {
1976 self.training = false;
1977 }
1978
1979 fn is_training(&self) -> bool {
1980 self.training
1981 }
1982}
1983
1984#[cfg(test)]
1989mod tests {
1990 use super::*;
1991
1992 #[test]
1997 fn test_lstm_new_basic() {
1998 let lstm = LSTM::<f32>::new(10, 20, 1).unwrap();
1999 assert_eq!(lstm.input_size(), 10);
2000 assert_eq!(lstm.hidden_size(), 20);
2001 assert_eq!(lstm.num_layers(), 1);
2002 }
2003
2004 #[test]
2005 fn test_lstm_parameter_count() {
2006 let lstm = LSTM::<f32>::new(10, 20, 2).unwrap();
2007 let params = lstm.parameters();
2010 assert_eq!(params.len(), 8); }
2012
2013 #[test]
2014 fn test_lstm_parameter_shapes() {
2015 let lstm = LSTM::<f32>::new(10, 20, 1).unwrap();
2016 let params = lstm.parameters();
2017 assert_eq!(params[0].shape(), &[80, 10]);
2019 assert_eq!(params[1].shape(), &[80, 20]);
2021 assert_eq!(params[2].shape(), &[80]);
2023 assert_eq!(params[3].shape(), &[80]);
2025 }
2026
2027 #[test]
2028 fn test_lstm_new_invalid_num_layers() {
2029 assert!(LSTM::<f32>::new(10, 20, 0).is_err());
2030 }
2031
2032 #[test]
2033 fn test_lstm_new_invalid_hidden_size() {
2034 assert!(LSTM::<f32>::new(10, 0, 1).is_err());
2035 }
2036
2037 #[test]
2038 fn test_lstm_new_invalid_input_size() {
2039 assert!(LSTM::<f32>::new(0, 20, 1).is_err());
2040 }
2041
2042 #[test]
2047 fn test_lstm_weight_init_range() {
2048 let hs = 100;
2049 let lstm = LSTM::<f32>::new(50, hs, 1).unwrap();
2050 let k = 1.0 / (hs as f32).sqrt();
2051 let params = lstm.parameters();
2052
2053 for param in ¶ms[..2] {
2055 let data = param.data().unwrap();
2056 for &v in data {
2057 assert!(
2058 v.abs() <= k + 0.01,
2059 "weight value {v} exceeds expected range [-{k}, {k}]"
2060 );
2061 }
2062 }
2063
2064 for param in ¶ms[2..4] {
2066 let data = param.data().unwrap();
2067 assert!(
2068 data.iter().all(|&v| v == 0.0),
2069 "bias should be initialized to zeros"
2070 );
2071 }
2072 }
2073
2074 #[test]
2079 fn test_lstm_forward_output_shape() {
2080 let lstm = LSTM::<f32>::new(10, 20, 1).unwrap();
2081 let input = ferrotorch_core::zeros::<f32>(&[2, 5, 10]).unwrap(); let (output, (h_n, c_n)) = lstm.forward_with_state(&input, None).unwrap();
2084
2085 assert_eq!(output.shape(), &[2, 5, 20]); assert_eq!(h_n.shape(), &[1, 2, 20]); assert_eq!(c_n.shape(), &[1, 2, 20]);
2088 }
2089
2090 #[test]
2091 fn test_lstm_forward_multi_layer_shapes() {
2092 let lstm = LSTM::<f32>::new(8, 16, 3).unwrap();
2093 let input = ferrotorch_core::zeros::<f32>(&[4, 7, 8]).unwrap(); let (output, (h_n, c_n)) = lstm.forward_with_state(&input, None).unwrap();
2096
2097 assert_eq!(output.shape(), &[4, 7, 16]); assert_eq!(h_n.shape(), &[3, 4, 16]); assert_eq!(c_n.shape(), &[3, 4, 16]);
2100 }
2101
2102 #[test]
2103 fn test_lstm_module_forward_shape() {
2104 let lstm = LSTM::<f32>::new(10, 20, 1).unwrap();
2105 let input = ferrotorch_core::zeros::<f32>(&[2, 5, 10]).unwrap();
2106
2107 let output = lstm.forward(&input).unwrap();
2108 assert_eq!(output.shape(), &[2, 5, 20]);
2109 }
2110
2111 #[test]
2116 fn test_lstm_forward_does_not_error() {
2117 let lstm = LSTM::<f32>::new(4, 8, 2).unwrap();
2118 let input = ferrotorch_core::randn::<f32>(&[3, 10, 4]).unwrap();
2119
2120 let result = lstm.forward_with_state(&input, None);
2121 assert!(
2122 result.is_ok(),
2123 "forward should not error: {:?}",
2124 result.err()
2125 );
2126 }
2127
2128 #[test]
2129 fn test_lstm_forward_nonzero_output() {
2130 let lstm = LSTM::<f32>::new(4, 8, 1).unwrap();
2132 let input = ferrotorch_core::randn::<f32>(&[1, 3, 4]).unwrap();
2133
2134 let (output, _) = lstm.forward_with_state(&input, None).unwrap();
2135 let data = output.data().unwrap();
2136 let any_nonzero = data.iter().any(|&v| v.abs() > 1e-10);
2137 assert!(any_nonzero, "output should have non-zero values");
2138 }
2139
2140 #[test]
2141 fn test_lstm_forward_seq_len_1() {
2142 let lstm = LSTM::<f32>::new(4, 8, 1).unwrap();
2143 let input = ferrotorch_core::zeros::<f32>(&[1, 1, 4]).unwrap();
2144
2145 let (output, (h_n, c_n)) = lstm.forward_with_state(&input, None).unwrap();
2146 assert_eq!(output.shape(), &[1, 1, 8]);
2147 assert_eq!(h_n.shape(), &[1, 1, 8]);
2148 assert_eq!(c_n.shape(), &[1, 1, 8]);
2149 }
2150
2151 #[test]
2156 fn test_lstm_forward_with_initial_state() {
2157 let lstm = LSTM::<f32>::new(4, 8, 1).unwrap();
2158
2159 let h0 = ferrotorch_core::zeros::<f32>(&[1, 2, 8]).unwrap();
2160 let c0 = ferrotorch_core::zeros::<f32>(&[1, 2, 8]).unwrap();
2161 let input = ferrotorch_core::randn::<f32>(&[2, 3, 4]).unwrap();
2162
2163 let result = lstm.forward_with_state(&input, Some((&h0, &c0)));
2164 assert!(result.is_ok());
2165 }
2166
2167 #[test]
2168 fn test_lstm_forward_state_shape_mismatch() {
2169 let lstm = LSTM::<f32>::new(4, 8, 1).unwrap();
2170
2171 let h0 = ferrotorch_core::zeros::<f32>(&[1, 3, 8]).unwrap();
2173 let c0 = ferrotorch_core::zeros::<f32>(&[1, 2, 8]).unwrap();
2174 let input = ferrotorch_core::randn::<f32>(&[2, 3, 4]).unwrap();
2175
2176 assert!(lstm.forward_with_state(&input, Some((&h0, &c0))).is_err());
2177 }
2178
2179 #[test]
2180 fn test_lstm_forward_input_wrong_ndim() {
2181 let lstm = LSTM::<f32>::new(4, 8, 1).unwrap();
2182 let input = ferrotorch_core::zeros::<f32>(&[10, 4]).unwrap(); assert!(lstm.forward_with_state(&input, None).is_err());
2184 }
2185
2186 #[test]
2187 fn test_lstm_forward_input_size_mismatch() {
2188 let lstm = LSTM::<f32>::new(4, 8, 1).unwrap();
2189 let input = ferrotorch_core::zeros::<f32>(&[1, 5, 7]).unwrap(); assert!(lstm.forward_with_state(&input, None).is_err());
2191 }
2192
2193 #[test]
2198 fn test_lstm_named_parameters() {
2199 let lstm = LSTM::<f32>::new(4, 8, 2).unwrap();
2200 let named = lstm.named_parameters();
2201 assert_eq!(named.len(), 8);
2202 assert_eq!(named[0].0, "layers.0.weight_ih");
2203 assert_eq!(named[1].0, "layers.0.weight_hh");
2204 assert_eq!(named[2].0, "layers.0.bias_ih");
2205 assert_eq!(named[3].0, "layers.0.bias_hh");
2206 assert_eq!(named[4].0, "layers.1.weight_ih");
2207 assert_eq!(named[5].0, "layers.1.weight_hh");
2208 assert_eq!(named[6].0, "layers.1.bias_ih");
2209 assert_eq!(named[7].0, "layers.1.bias_hh");
2210 }
2211
2212 #[test]
2213 fn test_lstm_train_eval() {
2214 let mut lstm = LSTM::<f32>::new(4, 8, 1).unwrap();
2215 assert!(lstm.is_training());
2216 lstm.eval();
2217 assert!(!lstm.is_training());
2218 lstm.train();
2219 assert!(lstm.is_training());
2220 }
2221
2222 #[test]
2223 fn test_lstm_all_parameters_require_grad() {
2224 let lstm = LSTM::<f32>::new(4, 8, 1).unwrap();
2225 for param in lstm.parameters() {
2226 assert!(param.requires_grad());
2227 }
2228 }
2229
2230 #[test]
2231 fn test_lstm_is_send_sync() {
2232 fn assert_send_sync<T: Send + Sync>() {}
2233 assert_send_sync::<LSTM<f32>>();
2234 assert_send_sync::<LSTM<f64>>();
2235 }
2236
2237 #[test]
2242 fn test_lstm_multi_layer_weight_shapes() {
2243 let lstm = LSTM::<f32>::new(10, 20, 3).unwrap();
2244 let params = lstm.parameters();
2245
2246 assert_eq!(params[0].shape(), &[80, 10]);
2248 assert_eq!(params[1].shape(), &[80, 20]);
2250
2251 assert_eq!(params[4].shape(), &[80, 20]);
2253 assert_eq!(params[5].shape(), &[80, 20]);
2255
2256 assert_eq!(params[8].shape(), &[80, 20]);
2258 }
2259
2260 #[test]
2265 fn test_lstm_state_dict_roundtrip() {
2266 let lstm = LSTM::<f32>::new(4, 8, 1).unwrap();
2267 let sd = lstm.state_dict();
2268 assert_eq!(sd.len(), 4);
2269 assert!(sd.contains_key("layers.0.weight_ih"));
2270 assert!(sd.contains_key("layers.0.weight_hh"));
2271 assert!(sd.contains_key("layers.0.bias_ih"));
2272 assert!(sd.contains_key("layers.0.bias_hh"));
2273
2274 let mut lstm2 = LSTM::<f32>::new(4, 8, 1).unwrap();
2275 lstm2.load_state_dict(&sd, true).unwrap();
2276 }
2277
2278 #[test]
2283 fn test_lstm_deterministic() {
2284 let lstm = LSTM::<f32>::new(4, 8, 1).unwrap();
2285 let input = ferrotorch_core::from_slice::<f32>(
2286 &[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8],
2287 &[1, 2, 4],
2288 )
2289 .unwrap();
2290
2291 let (out1, _) = lstm.forward_with_state(&input, None).unwrap();
2292 let (out2, _) = lstm.forward_with_state(&input, None).unwrap();
2293
2294 let d1 = out1.data().unwrap();
2295 let d2 = out2.data().unwrap();
2296 for (i, (&a, &b)) in d1.iter().zip(d2.iter()).enumerate() {
2297 assert!(
2298 (a - b).abs() < 1e-6,
2299 "output mismatch at index {i}: {a} vs {b}"
2300 );
2301 }
2302 }
2303
2304 #[test]
2313 fn test_gru_new_basic() {
2314 let gru = GRU::<f32>::new(10, 20).unwrap();
2315 assert_eq!(gru.input_size(), 10);
2316 assert_eq!(gru.hidden_size(), 20);
2317 assert_eq!(gru.num_layers(), 1);
2318 }
2319
2320 #[test]
2321 fn test_gru_with_num_layers() {
2322 let gru = GRU::<f32>::with_num_layers(10, 20, 3).unwrap();
2323 assert_eq!(gru.num_layers(), 3);
2324 }
2325
2326 #[test]
2327 fn test_gru_parameter_count() {
2328 let gru = GRU::<f32>::with_num_layers(10, 20, 2).unwrap();
2329 let params = gru.parameters();
2332 assert_eq!(params.len(), 8); }
2334
2335 #[test]
2336 fn test_gru_parameter_shapes() {
2337 let gru = GRU::<f32>::new(10, 20).unwrap();
2338 let params = gru.parameters();
2339 assert_eq!(params[0].shape(), &[60, 10]);
2341 assert_eq!(params[1].shape(), &[60, 20]);
2343 assert_eq!(params[2].shape(), &[60]);
2345 assert_eq!(params[3].shape(), &[60]);
2347 }
2348
2349 #[test]
2350 fn test_gru_new_invalid_num_layers() {
2351 assert!(GRU::<f32>::with_num_layers(10, 20, 0).is_err());
2352 }
2353
2354 #[test]
2355 fn test_gru_new_invalid_hidden_size() {
2356 assert!(GRU::<f32>::new(10, 0).is_err());
2357 }
2358
2359 #[test]
2360 fn test_gru_new_invalid_input_size() {
2361 assert!(GRU::<f32>::new(0, 20).is_err());
2362 }
2363
2364 #[test]
2369 fn test_gru_weight_init_range() {
2370 let hs = 100;
2371 let gru = GRU::<f32>::new(50, hs).unwrap();
2372 let k = 1.0 / (hs as f32).sqrt();
2373 let params = gru.parameters();
2374
2375 for param in ¶ms[..2] {
2377 let data = param.data().unwrap();
2378 for &v in data {
2379 assert!(
2380 v.abs() <= k + 0.01,
2381 "weight value {v} exceeds expected range [-{k}, {k}]"
2382 );
2383 }
2384 }
2385
2386 for param in ¶ms[2..4] {
2388 let data = param.data().unwrap();
2389 assert!(
2390 data.iter().all(|&v| v == 0.0),
2391 "bias should be initialized to zeros"
2392 );
2393 }
2394 }
2395
2396 #[test]
2401 fn test_gru_forward_output_shape() {
2402 let gru = GRU::<f32>::new(10, 20).unwrap();
2403 let input = ferrotorch_core::zeros::<f32>(&[2, 5, 10]).unwrap();
2404
2405 let (output, h_n) = gru.forward(&input, None).unwrap();
2406
2407 assert_eq!(output.shape(), &[2, 5, 20]); assert_eq!(h_n.shape(), &[1, 2, 20]); }
2410
2411 #[test]
2412 fn test_gru_forward_multi_layer_shapes() {
2413 let gru = GRU::<f32>::with_num_layers(8, 16, 3).unwrap();
2414 let input = ferrotorch_core::zeros::<f32>(&[4, 7, 8]).unwrap();
2415
2416 let (output, h_n) = gru.forward(&input, None).unwrap();
2417
2418 assert_eq!(output.shape(), &[4, 7, 16]);
2419 assert_eq!(h_n.shape(), &[3, 4, 16]);
2420 }
2421
2422 #[test]
2423 fn test_gru_module_forward_shape() {
2424 let gru = GRU::<f32>::new(10, 20).unwrap();
2425 let input = ferrotorch_core::zeros::<f32>(&[2, 5, 10]).unwrap();
2426
2427 let output = <GRU<f32> as Module<f32>>::forward(&gru, &input).unwrap();
2428 assert_eq!(output.shape(), &[2, 5, 20]);
2429 }
2430
2431 #[test]
2436 fn test_gru_forward_does_not_error() {
2437 let gru = GRU::<f32>::with_num_layers(4, 8, 2).unwrap();
2438 let input = ferrotorch_core::randn::<f32>(&[3, 10, 4]).unwrap();
2439
2440 let result = gru.forward(&input, None);
2441 assert!(
2442 result.is_ok(),
2443 "forward should not error: {:?}",
2444 result.err()
2445 );
2446 }
2447
2448 #[test]
2449 fn test_gru_forward_nonzero_output() {
2450 let gru = GRU::<f32>::new(4, 8).unwrap();
2451 let input = ferrotorch_core::randn::<f32>(&[1, 3, 4]).unwrap();
2452
2453 let (output, _) = gru.forward(&input, None).unwrap();
2454 let data = output.data().unwrap();
2455 let any_nonzero = data.iter().any(|&v| v.abs() > 1e-10);
2456 assert!(any_nonzero, "output should have non-zero values");
2457 }
2458
2459 #[test]
2460 fn test_gru_forward_seq_len_1() {
2461 let gru = GRU::<f32>::new(4, 8).unwrap();
2462 let input = ferrotorch_core::zeros::<f32>(&[1, 1, 4]).unwrap();
2463
2464 let (output, h_n) = gru.forward(&input, None).unwrap();
2465 assert_eq!(output.shape(), &[1, 1, 8]);
2466 assert_eq!(h_n.shape(), &[1, 1, 8]);
2467 }
2468
2469 #[test]
2474 fn test_gru_forward_with_initial_state() {
2475 let gru = GRU::<f32>::new(4, 8).unwrap();
2476
2477 let h0 = ferrotorch_core::zeros::<f32>(&[1, 2, 8]).unwrap();
2478 let input = ferrotorch_core::randn::<f32>(&[2, 3, 4]).unwrap();
2479
2480 let result = gru.forward(&input, Some(&h0));
2481 assert!(result.is_ok());
2482 }
2483
2484 #[test]
2485 fn test_gru_forward_state_shape_mismatch() {
2486 let gru = GRU::<f32>::new(4, 8).unwrap();
2487
2488 let h0 = ferrotorch_core::zeros::<f32>(&[1, 3, 8]).unwrap();
2490 let input = ferrotorch_core::randn::<f32>(&[2, 3, 4]).unwrap();
2491
2492 assert!(gru.forward(&input, Some(&h0)).is_err());
2493 }
2494
2495 #[test]
2496 fn test_gru_forward_input_wrong_ndim() {
2497 let gru = GRU::<f32>::new(4, 8).unwrap();
2498 let input = ferrotorch_core::zeros::<f32>(&[10, 4]).unwrap();
2499 assert!(gru.forward(&input, None).is_err());
2500 }
2501
2502 #[test]
2503 fn test_gru_forward_input_size_mismatch() {
2504 let gru = GRU::<f32>::new(4, 8).unwrap();
2505 let input = ferrotorch_core::zeros::<f32>(&[1, 5, 7]).unwrap();
2506 assert!(gru.forward(&input, None).is_err());
2507 }
2508
2509 #[test]
2514 fn test_gru_named_parameters() {
2515 let gru = GRU::<f32>::with_num_layers(4, 8, 2).unwrap();
2516 let named = gru.named_parameters();
2517 assert_eq!(named.len(), 8);
2518 assert_eq!(named[0].0, "layers.0.weight_ih");
2519 assert_eq!(named[1].0, "layers.0.weight_hh");
2520 assert_eq!(named[2].0, "layers.0.bias_ih");
2521 assert_eq!(named[3].0, "layers.0.bias_hh");
2522 assert_eq!(named[4].0, "layers.1.weight_ih");
2523 assert_eq!(named[5].0, "layers.1.weight_hh");
2524 assert_eq!(named[6].0, "layers.1.bias_ih");
2525 assert_eq!(named[7].0, "layers.1.bias_hh");
2526 }
2527
2528 #[test]
2529 fn test_gru_train_eval() {
2530 let mut gru = GRU::<f32>::new(4, 8).unwrap();
2531 assert!(gru.is_training());
2532 gru.eval();
2533 assert!(!gru.is_training());
2534 gru.train();
2535 assert!(gru.is_training());
2536 }
2537
2538 #[test]
2539 fn test_gru_all_parameters_require_grad() {
2540 let gru = GRU::<f32>::new(4, 8).unwrap();
2541 for param in gru.parameters() {
2542 assert!(param.requires_grad());
2543 }
2544 }
2545
2546 #[test]
2547 fn test_gru_is_send_sync() {
2548 fn assert_send_sync<T: Send + Sync>() {}
2549 assert_send_sync::<GRU<f32>>();
2550 assert_send_sync::<GRU<f64>>();
2551 }
2552
2553 #[test]
2558 fn test_gru_multi_layer_weight_shapes() {
2559 let gru = GRU::<f32>::with_num_layers(10, 20, 3).unwrap();
2560 let params = gru.parameters();
2561
2562 assert_eq!(params[0].shape(), &[60, 10]);
2564 assert_eq!(params[1].shape(), &[60, 20]);
2566
2567 assert_eq!(params[4].shape(), &[60, 20]);
2569 assert_eq!(params[5].shape(), &[60, 20]);
2571
2572 assert_eq!(params[8].shape(), &[60, 20]);
2574 }
2575
2576 #[test]
2581 fn test_gru_state_dict_roundtrip() {
2582 let gru = GRU::<f32>::new(4, 8).unwrap();
2583 let sd = gru.state_dict();
2584 assert_eq!(sd.len(), 4);
2585 assert!(sd.contains_key("layers.0.weight_ih"));
2586 assert!(sd.contains_key("layers.0.weight_hh"));
2587 assert!(sd.contains_key("layers.0.bias_ih"));
2588 assert!(sd.contains_key("layers.0.bias_hh"));
2589
2590 let mut gru2 = GRU::<f32>::new(4, 8).unwrap();
2591 gru2.load_state_dict(&sd, true).unwrap();
2592 }
2593
2594 #[test]
2599 fn test_gru_deterministic() {
2600 let gru = GRU::<f32>::new(4, 8).unwrap();
2601 let input = ferrotorch_core::from_slice::<f32>(
2602 &[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8],
2603 &[1, 2, 4],
2604 )
2605 .unwrap();
2606
2607 let (out1, _) = gru.forward(&input, None).unwrap();
2608 let (out2, _) = gru.forward(&input, None).unwrap();
2609
2610 let d1 = out1.data().unwrap();
2611 let d2 = out2.data().unwrap();
2612 for (i, (&a, &b)) in d1.iter().zip(d2.iter()).enumerate() {
2613 assert!(
2614 (a - b).abs() < 1e-6,
2615 "output mismatch at index {i}: {a} vs {b}"
2616 );
2617 }
2618 }
2619
2620 #[test]
2625 fn test_rnn_cell_new_basic() {
2626 let cell = RNNCell::<f32>::new(10, 20).unwrap();
2627 assert_eq!(cell.input_size(), 10);
2628 assert_eq!(cell.hidden_size(), 20);
2629 assert_eq!(cell.nonlinearity(), RNNNonlinearity::Tanh);
2630 }
2631
2632 #[test]
2633 fn test_rnn_cell_relu() {
2634 let cell = RNNCell::<f32>::with_nonlinearity(10, 20, RNNNonlinearity::ReLU).unwrap();
2635 assert_eq!(cell.nonlinearity(), RNNNonlinearity::ReLU);
2636 }
2637
2638 #[test]
2639 fn test_rnn_cell_invalid_sizes() {
2640 assert!(RNNCell::<f32>::new(0, 20).is_err());
2641 assert!(RNNCell::<f32>::new(10, 0).is_err());
2642 }
2643
2644 #[test]
2645 fn test_rnn_cell_parameter_shapes() {
2646 let cell = RNNCell::<f32>::new(10, 20).unwrap();
2647 let params = cell.parameters();
2648 assert_eq!(params.len(), 4);
2649 assert_eq!(params[0].shape(), &[20, 10]); assert_eq!(params[1].shape(), &[20, 20]); assert_eq!(params[2].shape(), &[20]); assert_eq!(params[3].shape(), &[20]); }
2654
2655 #[test]
2656 fn test_rnn_cell_forward_output_shape() {
2657 let cell = RNNCell::<f32>::new(10, 20).unwrap();
2658 let x = ferrotorch_core::randn::<f32>(&[3, 10]).unwrap();
2659 let h = cell.forward_cell(&x, None).unwrap();
2660 assert_eq!(h.shape(), &[3, 20]);
2661 }
2662
2663 #[test]
2664 fn test_rnn_cell_forward_with_hidden() {
2665 let cell = RNNCell::<f32>::new(10, 20).unwrap();
2666 let x = ferrotorch_core::randn::<f32>(&[3, 10]).unwrap();
2667 let h0 = ferrotorch_core::randn::<f32>(&[3, 20]).unwrap();
2668 let h = cell.forward_cell(&x, Some(&h0)).unwrap();
2669 assert_eq!(h.shape(), &[3, 20]);
2670 }
2671
2672 #[test]
2673 fn test_rnn_cell_forward_nonzero() {
2674 let cell = RNNCell::<f32>::new(4, 8).unwrap();
2675 let x = ferrotorch_core::randn::<f32>(&[1, 4]).unwrap();
2676 let h = cell.forward_cell(&x, None).unwrap();
2677 let data = h.data().unwrap();
2678 assert!(data.iter().any(|&v| v.abs() > 1e-10));
2679 }
2680
2681 #[test]
2682 fn test_rnn_cell_forward_bad_input_ndim() {
2683 let cell = RNNCell::<f32>::new(4, 8).unwrap();
2684 let x = ferrotorch_core::zeros::<f32>(&[1, 2, 4]).unwrap();
2685 assert!(cell.forward_cell(&x, None).is_err());
2686 }
2687
2688 #[test]
2689 fn test_rnn_cell_forward_bad_input_size() {
2690 let cell = RNNCell::<f32>::new(4, 8).unwrap();
2691 let x = ferrotorch_core::zeros::<f32>(&[1, 7]).unwrap();
2692 assert!(cell.forward_cell(&x, None).is_err());
2693 }
2694
2695 #[test]
2696 fn test_rnn_cell_forward_bad_h_shape() {
2697 let cell = RNNCell::<f32>::new(4, 8).unwrap();
2698 let x = ferrotorch_core::zeros::<f32>(&[2, 4]).unwrap();
2699 let h0 = ferrotorch_core::zeros::<f32>(&[3, 8]).unwrap(); assert!(cell.forward_cell(&x, Some(&h0)).is_err());
2701 }
2702
2703 #[test]
2704 fn test_rnn_cell_module_forward() {
2705 let cell = RNNCell::<f32>::new(4, 8).unwrap();
2706 let x = ferrotorch_core::randn::<f32>(&[2, 4]).unwrap();
2707 let h = <RNNCell<f32> as Module<f32>>::forward(&cell, &x).unwrap();
2708 assert_eq!(h.shape(), &[2, 8]);
2709 }
2710
2711 #[test]
2712 fn test_rnn_cell_named_parameters() {
2713 let cell = RNNCell::<f32>::new(4, 8).unwrap();
2714 let named = cell.named_parameters();
2715 assert_eq!(named.len(), 4);
2716 assert_eq!(named[0].0, "weight_ih");
2717 assert_eq!(named[1].0, "weight_hh");
2718 assert_eq!(named[2].0, "bias_ih");
2719 assert_eq!(named[3].0, "bias_hh");
2720 }
2721
2722 #[test]
2723 fn test_rnn_cell_train_eval() {
2724 let mut cell = RNNCell::<f32>::new(4, 8).unwrap();
2725 assert!(cell.is_training());
2726 cell.eval();
2727 assert!(!cell.is_training());
2728 cell.train();
2729 assert!(cell.is_training());
2730 }
2731
2732 #[test]
2733 fn test_rnn_cell_deterministic() {
2734 let cell = RNNCell::<f32>::new(4, 8).unwrap();
2735 let x = ferrotorch_core::from_slice::<f32>(&[0.1, 0.2, 0.3, 0.4], &[1, 4]).unwrap();
2736 let h1 = cell.forward_cell(&x, None).unwrap();
2737 let h2 = cell.forward_cell(&x, None).unwrap();
2738 let d1 = h1.data().unwrap();
2739 let d2 = h2.data().unwrap();
2740 for (i, (&a, &b)) in d1.iter().zip(d2.iter()).enumerate() {
2741 assert!((a - b).abs() < 1e-6, "mismatch at {i}: {a} vs {b}");
2742 }
2743 }
2744
2745 #[test]
2746 fn test_rnn_cell_relu_output_nonneg() {
2747 let cell = RNNCell::<f32>::with_nonlinearity(4, 8, RNNNonlinearity::ReLU).unwrap();
2749 let x = ferrotorch_core::randn::<f32>(&[5, 4]).unwrap();
2750 let h = cell.forward_cell(&x, None).unwrap();
2751 let data = h.data().unwrap();
2752 assert!(
2753 data.iter().all(|&v| v >= 0.0),
2754 "relu output should be non-negative"
2755 );
2756 }
2757
2758 #[test]
2759 fn test_rnn_cell_is_send_sync() {
2760 fn assert_send_sync<T: Send + Sync>() {}
2761 assert_send_sync::<RNNCell<f32>>();
2762 assert_send_sync::<RNNCell<f64>>();
2763 }
2764
2765 #[test]
2770 fn test_lstm_cell_new_basic() {
2771 let cell = LSTMCell::<f32>::new(10, 20).unwrap();
2772 assert_eq!(cell.input_size(), 10);
2773 assert_eq!(cell.hidden_size(), 20);
2774 }
2775
2776 #[test]
2777 fn test_lstm_cell_invalid_sizes() {
2778 assert!(LSTMCell::<f32>::new(0, 20).is_err());
2779 assert!(LSTMCell::<f32>::new(10, 0).is_err());
2780 }
2781
2782 #[test]
2783 fn test_lstm_cell_parameter_shapes() {
2784 let cell = LSTMCell::<f32>::new(10, 20).unwrap();
2785 let params = cell.parameters();
2786 assert_eq!(params.len(), 4);
2787 assert_eq!(params[0].shape(), &[80, 10]); assert_eq!(params[1].shape(), &[80, 20]); assert_eq!(params[2].shape(), &[80]); assert_eq!(params[3].shape(), &[80]); }
2792
2793 #[test]
2794 fn test_lstm_cell_forward_output_shape() {
2795 let cell = LSTMCell::<f32>::new(10, 20).unwrap();
2796 let x = ferrotorch_core::randn::<f32>(&[3, 10]).unwrap();
2797 let (h, c) = cell.forward_cell(&x, None).unwrap();
2798 assert_eq!(h.shape(), &[3, 20]);
2799 assert_eq!(c.shape(), &[3, 20]);
2800 }
2801
2802 #[test]
2803 fn test_lstm_cell_forward_with_state() {
2804 let cell = LSTMCell::<f32>::new(10, 20).unwrap();
2805 let x = ferrotorch_core::randn::<f32>(&[3, 10]).unwrap();
2806 let h0 = ferrotorch_core::randn::<f32>(&[3, 20]).unwrap();
2807 let c0 = ferrotorch_core::randn::<f32>(&[3, 20]).unwrap();
2808 let (h, c) = cell.forward_cell(&x, Some((&h0, &c0))).unwrap();
2809 assert_eq!(h.shape(), &[3, 20]);
2810 assert_eq!(c.shape(), &[3, 20]);
2811 }
2812
2813 #[test]
2814 fn test_lstm_cell_forward_nonzero() {
2815 let cell = LSTMCell::<f32>::new(4, 8).unwrap();
2816 let x = ferrotorch_core::randn::<f32>(&[1, 4]).unwrap();
2817 let (h, c) = cell.forward_cell(&x, None).unwrap();
2818 let hd = h.data().unwrap();
2819 let cd = c.data().unwrap();
2820 assert!(hd.iter().any(|&v| v.abs() > 1e-10));
2821 assert!(cd.iter().any(|&v| v.abs() > 1e-10));
2822 }
2823
2824 #[test]
2825 fn test_lstm_cell_forward_bad_input_ndim() {
2826 let cell = LSTMCell::<f32>::new(4, 8).unwrap();
2827 let x = ferrotorch_core::zeros::<f32>(&[1, 2, 4]).unwrap();
2828 assert!(cell.forward_cell(&x, None).is_err());
2829 }
2830
2831 #[test]
2832 fn test_lstm_cell_forward_bad_input_size() {
2833 let cell = LSTMCell::<f32>::new(4, 8).unwrap();
2834 let x = ferrotorch_core::zeros::<f32>(&[1, 7]).unwrap();
2835 assert!(cell.forward_cell(&x, None).is_err());
2836 }
2837
2838 #[test]
2839 fn test_lstm_cell_forward_bad_h_shape() {
2840 let cell = LSTMCell::<f32>::new(4, 8).unwrap();
2841 let x = ferrotorch_core::zeros::<f32>(&[2, 4]).unwrap();
2842 let h0 = ferrotorch_core::zeros::<f32>(&[3, 8]).unwrap(); let c0 = ferrotorch_core::zeros::<f32>(&[2, 8]).unwrap();
2844 assert!(cell.forward_cell(&x, Some((&h0, &c0))).is_err());
2845 }
2846
2847 #[test]
2848 fn test_lstm_cell_forward_bad_c_shape() {
2849 let cell = LSTMCell::<f32>::new(4, 8).unwrap();
2850 let x = ferrotorch_core::zeros::<f32>(&[2, 4]).unwrap();
2851 let h0 = ferrotorch_core::zeros::<f32>(&[2, 8]).unwrap();
2852 let c0 = ferrotorch_core::zeros::<f32>(&[2, 99]).unwrap(); assert!(cell.forward_cell(&x, Some((&h0, &c0))).is_err());
2854 }
2855
2856 #[test]
2857 fn test_lstm_cell_module_forward() {
2858 let cell = LSTMCell::<f32>::new(4, 8).unwrap();
2859 let x = ferrotorch_core::randn::<f32>(&[2, 4]).unwrap();
2860 let h = <LSTMCell<f32> as Module<f32>>::forward(&cell, &x).unwrap();
2862 assert_eq!(h.shape(), &[2, 8]);
2863 }
2864
2865 #[test]
2866 fn test_lstm_cell_named_parameters() {
2867 let cell = LSTMCell::<f32>::new(4, 8).unwrap();
2868 let named = cell.named_parameters();
2869 assert_eq!(named.len(), 4);
2870 assert_eq!(named[0].0, "weight_ih");
2871 assert_eq!(named[1].0, "weight_hh");
2872 assert_eq!(named[2].0, "bias_ih");
2873 assert_eq!(named[3].0, "bias_hh");
2874 }
2875
2876 #[test]
2877 fn test_lstm_cell_train_eval() {
2878 let mut cell = LSTMCell::<f32>::new(4, 8).unwrap();
2879 assert!(cell.is_training());
2880 cell.eval();
2881 assert!(!cell.is_training());
2882 cell.train();
2883 assert!(cell.is_training());
2884 }
2885
2886 #[test]
2887 fn test_lstm_cell_deterministic() {
2888 let cell = LSTMCell::<f32>::new(4, 8).unwrap();
2889 let x = ferrotorch_core::from_slice::<f32>(&[0.1, 0.2, 0.3, 0.4], &[1, 4]).unwrap();
2890 let (h1, c1) = cell.forward_cell(&x, None).unwrap();
2891 let (h2, c2) = cell.forward_cell(&x, None).unwrap();
2892 let hd1 = h1.data().unwrap();
2893 let hd2 = h2.data().unwrap();
2894 let cd1 = c1.data().unwrap();
2895 let cd2 = c2.data().unwrap();
2896 for (i, (&a, &b)) in hd1.iter().zip(hd2.iter()).enumerate() {
2897 assert!((a - b).abs() < 1e-6, "h mismatch at {i}: {a} vs {b}");
2898 }
2899 for (i, (&a, &b)) in cd1.iter().zip(cd2.iter()).enumerate() {
2900 assert!((a - b).abs() < 1e-6, "c mismatch at {i}: {a} vs {b}");
2901 }
2902 }
2903
2904 #[test]
2905 fn test_lstm_cell_h_bounded_by_tanh() {
2906 let cell = LSTMCell::<f32>::new(4, 8).unwrap();
2908 let x = ferrotorch_core::randn::<f32>(&[10, 4]).unwrap();
2909 let (h, _c) = cell.forward_cell(&x, None).unwrap();
2910 let data = h.data().unwrap();
2911 assert!(
2912 data.iter().all(|&v| v.abs() <= 1.0 + 1e-6),
2913 "LSTM cell h should be bounded by [-1, 1]"
2914 );
2915 }
2916
2917 #[test]
2918 fn test_lstm_cell_is_send_sync() {
2919 fn assert_send_sync<T: Send + Sync>() {}
2920 assert_send_sync::<LSTMCell<f32>>();
2921 assert_send_sync::<LSTMCell<f64>>();
2922 }
2923
2924 #[test]
2929 fn test_gru_cell_new_basic() {
2930 let cell = GRUCell::<f32>::new(10, 20).unwrap();
2931 assert_eq!(cell.input_size(), 10);
2932 assert_eq!(cell.hidden_size(), 20);
2933 }
2934
2935 #[test]
2936 fn test_gru_cell_invalid_sizes() {
2937 assert!(GRUCell::<f32>::new(0, 20).is_err());
2938 assert!(GRUCell::<f32>::new(10, 0).is_err());
2939 }
2940
2941 #[test]
2942 fn test_gru_cell_parameter_shapes() {
2943 let cell = GRUCell::<f32>::new(10, 20).unwrap();
2944 let params = cell.parameters();
2945 assert_eq!(params.len(), 4);
2946 assert_eq!(params[0].shape(), &[60, 10]); assert_eq!(params[1].shape(), &[60, 20]); assert_eq!(params[2].shape(), &[60]); assert_eq!(params[3].shape(), &[60]); }
2951
2952 #[test]
2953 fn test_gru_cell_forward_output_shape() {
2954 let cell = GRUCell::<f32>::new(10, 20).unwrap();
2955 let x = ferrotorch_core::randn::<f32>(&[3, 10]).unwrap();
2956 let h = cell.forward_cell(&x, None).unwrap();
2957 assert_eq!(h.shape(), &[3, 20]);
2958 }
2959
2960 #[test]
2961 fn test_gru_cell_forward_with_hidden() {
2962 let cell = GRUCell::<f32>::new(10, 20).unwrap();
2963 let x = ferrotorch_core::randn::<f32>(&[3, 10]).unwrap();
2964 let h0 = ferrotorch_core::randn::<f32>(&[3, 20]).unwrap();
2965 let h = cell.forward_cell(&x, Some(&h0)).unwrap();
2966 assert_eq!(h.shape(), &[3, 20]);
2967 }
2968
2969 #[test]
2970 fn test_gru_cell_forward_nonzero() {
2971 let cell = GRUCell::<f32>::new(4, 8).unwrap();
2972 let x = ferrotorch_core::randn::<f32>(&[1, 4]).unwrap();
2973 let h = cell.forward_cell(&x, None).unwrap();
2974 let data = h.data().unwrap();
2975 assert!(data.iter().any(|&v| v.abs() > 1e-10));
2976 }
2977
2978 #[test]
2979 fn test_gru_cell_forward_bad_input_ndim() {
2980 let cell = GRUCell::<f32>::new(4, 8).unwrap();
2981 let x = ferrotorch_core::zeros::<f32>(&[1, 2, 4]).unwrap();
2982 assert!(cell.forward_cell(&x, None).is_err());
2983 }
2984
2985 #[test]
2986 fn test_gru_cell_forward_bad_input_size() {
2987 let cell = GRUCell::<f32>::new(4, 8).unwrap();
2988 let x = ferrotorch_core::zeros::<f32>(&[1, 7]).unwrap();
2989 assert!(cell.forward_cell(&x, None).is_err());
2990 }
2991
2992 #[test]
2993 fn test_gru_cell_forward_bad_h_shape() {
2994 let cell = GRUCell::<f32>::new(4, 8).unwrap();
2995 let x = ferrotorch_core::zeros::<f32>(&[2, 4]).unwrap();
2996 let h0 = ferrotorch_core::zeros::<f32>(&[3, 8]).unwrap(); assert!(cell.forward_cell(&x, Some(&h0)).is_err());
2998 }
2999
3000 #[test]
3001 fn test_gru_cell_module_forward() {
3002 let cell = GRUCell::<f32>::new(4, 8).unwrap();
3003 let x = ferrotorch_core::randn::<f32>(&[2, 4]).unwrap();
3004 let h = <GRUCell<f32> as Module<f32>>::forward(&cell, &x).unwrap();
3005 assert_eq!(h.shape(), &[2, 8]);
3006 }
3007
3008 #[test]
3009 fn test_gru_cell_named_parameters() {
3010 let cell = GRUCell::<f32>::new(4, 8).unwrap();
3011 let named = cell.named_parameters();
3012 assert_eq!(named.len(), 4);
3013 assert_eq!(named[0].0, "weight_ih");
3014 assert_eq!(named[1].0, "weight_hh");
3015 assert_eq!(named[2].0, "bias_ih");
3016 assert_eq!(named[3].0, "bias_hh");
3017 }
3018
3019 #[test]
3020 fn test_gru_cell_train_eval() {
3021 let mut cell = GRUCell::<f32>::new(4, 8).unwrap();
3022 assert!(cell.is_training());
3023 cell.eval();
3024 assert!(!cell.is_training());
3025 cell.train();
3026 assert!(cell.is_training());
3027 }
3028
3029 #[test]
3030 fn test_gru_cell_deterministic() {
3031 let cell = GRUCell::<f32>::new(4, 8).unwrap();
3032 let x = ferrotorch_core::from_slice::<f32>(&[0.1, 0.2, 0.3, 0.4], &[1, 4]).unwrap();
3033 let h1 = cell.forward_cell(&x, None).unwrap();
3034 let h2 = cell.forward_cell(&x, None).unwrap();
3035 let d1 = h1.data().unwrap();
3036 let d2 = h2.data().unwrap();
3037 for (i, (&a, &b)) in d1.iter().zip(d2.iter()).enumerate() {
3038 assert!((a - b).abs() < 1e-6, "mismatch at {i}: {a} vs {b}");
3039 }
3040 }
3041
3042 #[test]
3043 fn test_gru_cell_is_send_sync() {
3044 fn assert_send_sync<T: Send + Sync>() {}
3045 assert_send_sync::<GRUCell<f32>>();
3046 assert_send_sync::<GRUCell<f64>>();
3047 }
3048
3049 #[test]
3054 fn test_rnn_new_basic() {
3055 let rnn = RNN::<f32>::new(10, 20).unwrap();
3056 assert_eq!(rnn.input_size(), 10);
3057 assert_eq!(rnn.hidden_size(), 20);
3058 assert_eq!(rnn.num_layers(), 1);
3059 assert_eq!(rnn.nonlinearity(), RNNNonlinearity::Tanh);
3060 }
3061
3062 #[test]
3063 fn test_rnn_with_options() {
3064 let rnn = RNN::<f32>::with_options(10, 20, 3, RNNNonlinearity::ReLU).unwrap();
3065 assert_eq!(rnn.num_layers(), 3);
3066 assert_eq!(rnn.nonlinearity(), RNNNonlinearity::ReLU);
3067 }
3068
3069 #[test]
3070 fn test_rnn_invalid_sizes() {
3071 assert!(RNN::<f32>::with_options(0, 20, 1, RNNNonlinearity::Tanh).is_err());
3072 assert!(RNN::<f32>::with_options(10, 0, 1, RNNNonlinearity::Tanh).is_err());
3073 assert!(RNN::<f32>::with_options(10, 20, 0, RNNNonlinearity::Tanh).is_err());
3074 }
3075
3076 #[test]
3077 fn test_rnn_parameter_count() {
3078 let rnn = RNN::<f32>::with_options(10, 20, 2, RNNNonlinearity::Tanh).unwrap();
3079 let params = rnn.parameters();
3080 assert_eq!(params.len(), 8); }
3082
3083 #[test]
3084 fn test_rnn_parameter_shapes() {
3085 let rnn = RNN::<f32>::new(10, 20).unwrap();
3086 let params = rnn.parameters();
3087 assert_eq!(params[0].shape(), &[20, 10]); assert_eq!(params[1].shape(), &[20, 20]); assert_eq!(params[2].shape(), &[20]); assert_eq!(params[3].shape(), &[20]); }
3092
3093 #[test]
3094 fn test_rnn_multi_layer_weight_shapes() {
3095 let rnn = RNN::<f32>::with_options(10, 20, 3, RNNNonlinearity::Tanh).unwrap();
3096 let params = rnn.parameters();
3097
3098 assert_eq!(params[0].shape(), &[20, 10]);
3100 assert_eq!(params[1].shape(), &[20, 20]);
3102
3103 assert_eq!(params[4].shape(), &[20, 20]);
3105 assert_eq!(params[5].shape(), &[20, 20]);
3107
3108 assert_eq!(params[8].shape(), &[20, 20]);
3110 }
3111
3112 #[test]
3113 fn test_rnn_forward_output_shape() {
3114 let rnn = RNN::<f32>::new(10, 20).unwrap();
3115 let input = ferrotorch_core::zeros::<f32>(&[2, 5, 10]).unwrap();
3116
3117 let (output, h_n) = rnn.forward_with_state(&input, None).unwrap();
3118
3119 assert_eq!(output.shape(), &[2, 5, 20]); assert_eq!(h_n.shape(), &[1, 2, 20]); }
3122
3123 #[test]
3124 fn test_rnn_forward_multi_layer_shapes() {
3125 let rnn = RNN::<f32>::with_options(8, 16, 3, RNNNonlinearity::Tanh).unwrap();
3126 let input = ferrotorch_core::zeros::<f32>(&[4, 7, 8]).unwrap();
3127
3128 let (output, h_n) = rnn.forward_with_state(&input, None).unwrap();
3129
3130 assert_eq!(output.shape(), &[4, 7, 16]);
3131 assert_eq!(h_n.shape(), &[3, 4, 16]);
3132 }
3133
3134 #[test]
3135 fn test_rnn_module_forward_shape() {
3136 let rnn = RNN::<f32>::new(10, 20).unwrap();
3137 let input = ferrotorch_core::zeros::<f32>(&[2, 5, 10]).unwrap();
3138 let output = <RNN<f32> as Module<f32>>::forward(&rnn, &input).unwrap();
3139 assert_eq!(output.shape(), &[2, 5, 20]);
3140 }
3141
3142 #[test]
3143 fn test_rnn_forward_does_not_error() {
3144 let rnn = RNN::<f32>::with_options(4, 8, 2, RNNNonlinearity::Tanh).unwrap();
3145 let input = ferrotorch_core::randn::<f32>(&[3, 10, 4]).unwrap();
3146 let result = rnn.forward_with_state(&input, None);
3147 assert!(
3148 result.is_ok(),
3149 "forward should not error: {:?}",
3150 result.err()
3151 );
3152 }
3153
3154 #[test]
3155 fn test_rnn_forward_nonzero_output() {
3156 let rnn = RNN::<f32>::new(4, 8).unwrap();
3157 let input = ferrotorch_core::randn::<f32>(&[1, 3, 4]).unwrap();
3158 let (output, _) = rnn.forward_with_state(&input, None).unwrap();
3159 let data = output.data().unwrap();
3160 assert!(data.iter().any(|&v| v.abs() > 1e-10));
3161 }
3162
3163 #[test]
3164 fn test_rnn_forward_seq_len_1() {
3165 let rnn = RNN::<f32>::new(4, 8).unwrap();
3166 let input = ferrotorch_core::zeros::<f32>(&[1, 1, 4]).unwrap();
3167 let (output, h_n) = rnn.forward_with_state(&input, None).unwrap();
3168 assert_eq!(output.shape(), &[1, 1, 8]);
3169 assert_eq!(h_n.shape(), &[1, 1, 8]);
3170 }
3171
3172 #[test]
3173 fn test_rnn_forward_with_initial_state() {
3174 let rnn = RNN::<f32>::new(4, 8).unwrap();
3175 let h0 = ferrotorch_core::zeros::<f32>(&[1, 2, 8]).unwrap();
3176 let input = ferrotorch_core::randn::<f32>(&[2, 3, 4]).unwrap();
3177 let result = rnn.forward_with_state(&input, Some(&h0));
3178 assert!(result.is_ok());
3179 }
3180
3181 #[test]
3182 fn test_rnn_forward_state_shape_mismatch() {
3183 let rnn = RNN::<f32>::new(4, 8).unwrap();
3184 let h0 = ferrotorch_core::zeros::<f32>(&[1, 3, 8]).unwrap(); let input = ferrotorch_core::randn::<f32>(&[2, 3, 4]).unwrap();
3186 assert!(rnn.forward_with_state(&input, Some(&h0)).is_err());
3187 }
3188
3189 #[test]
3190 fn test_rnn_forward_input_wrong_ndim() {
3191 let rnn = RNN::<f32>::new(4, 8).unwrap();
3192 let input = ferrotorch_core::zeros::<f32>(&[10, 4]).unwrap();
3193 assert!(rnn.forward_with_state(&input, None).is_err());
3194 }
3195
3196 #[test]
3197 fn test_rnn_forward_input_size_mismatch() {
3198 let rnn = RNN::<f32>::new(4, 8).unwrap();
3199 let input = ferrotorch_core::zeros::<f32>(&[1, 5, 7]).unwrap();
3200 assert!(rnn.forward_with_state(&input, None).is_err());
3201 }
3202
3203 #[test]
3204 fn test_rnn_named_parameters() {
3205 let rnn = RNN::<f32>::with_options(4, 8, 2, RNNNonlinearity::Tanh).unwrap();
3206 let named = rnn.named_parameters();
3207 assert_eq!(named.len(), 8);
3208 assert_eq!(named[0].0, "layers.0.weight_ih");
3209 assert_eq!(named[1].0, "layers.0.weight_hh");
3210 assert_eq!(named[2].0, "layers.0.bias_ih");
3211 assert_eq!(named[3].0, "layers.0.bias_hh");
3212 assert_eq!(named[4].0, "layers.1.weight_ih");
3213 assert_eq!(named[5].0, "layers.1.weight_hh");
3214 assert_eq!(named[6].0, "layers.1.bias_ih");
3215 assert_eq!(named[7].0, "layers.1.bias_hh");
3216 }
3217
3218 #[test]
3219 fn test_rnn_train_eval() {
3220 let mut rnn = RNN::<f32>::new(4, 8).unwrap();
3221 assert!(rnn.is_training());
3222 rnn.eval();
3223 assert!(!rnn.is_training());
3224 rnn.train();
3225 assert!(rnn.is_training());
3226 }
3227
3228 #[test]
3229 fn test_rnn_all_parameters_require_grad() {
3230 let rnn = RNN::<f32>::new(4, 8).unwrap();
3231 for param in rnn.parameters() {
3232 assert!(param.requires_grad());
3233 }
3234 }
3235
3236 #[test]
3237 fn test_rnn_deterministic() {
3238 let rnn = RNN::<f32>::new(4, 8).unwrap();
3239 let input = ferrotorch_core::from_slice::<f32>(
3240 &[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8],
3241 &[1, 2, 4],
3242 )
3243 .unwrap();
3244
3245 let (out1, _) = rnn.forward_with_state(&input, None).unwrap();
3246 let (out2, _) = rnn.forward_with_state(&input, None).unwrap();
3247
3248 let d1 = out1.data().unwrap();
3249 let d2 = out2.data().unwrap();
3250 for (i, (&a, &b)) in d1.iter().zip(d2.iter()).enumerate() {
3251 assert!(
3252 (a - b).abs() < 1e-6,
3253 "output mismatch at index {i}: {a} vs {b}"
3254 );
3255 }
3256 }
3257
3258 #[test]
3259 fn test_rnn_state_dict_roundtrip() {
3260 let rnn = RNN::<f32>::new(4, 8).unwrap();
3261 let sd = rnn.state_dict();
3262 assert_eq!(sd.len(), 4);
3263 assert!(sd.contains_key("layers.0.weight_ih"));
3264
3265 let mut rnn2 = RNN::<f32>::new(4, 8).unwrap();
3266 rnn2.load_state_dict(&sd, true).unwrap();
3267 }
3268
3269 #[test]
3270 fn test_rnn_relu_forward() {
3271 let rnn = RNN::<f32>::with_options(4, 8, 1, RNNNonlinearity::ReLU).unwrap();
3272 let input = ferrotorch_core::randn::<f32>(&[2, 3, 4]).unwrap();
3273 let (output, h_n) = rnn.forward_with_state(&input, None).unwrap();
3274 assert_eq!(output.shape(), &[2, 3, 8]);
3275 assert_eq!(h_n.shape(), &[1, 2, 8]);
3276 }
3277
3278 #[test]
3279 fn test_rnn_is_send_sync() {
3280 fn assert_send_sync<T: Send + Sync>() {}
3281 assert_send_sync::<RNN<f32>>();
3282 assert_send_sync::<RNN<f64>>();
3283 }
3284}