1use ferrotorch_core::grad_fns::activation as act;
25use ferrotorch_core::grad_fns::arithmetic;
26use ferrotorch_core::grad_fns::transcendental;
27use ferrotorch_core::ops::elementwise::unary_map;
28use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float, Tensor, normalize_axis};
29
30use crate::module::Module;
31use crate::parameter::Parameter;
32
/// Implements [`Module`] for a parameter-free activation type.
///
/// The type passed in must provide an inherent generic method
/// `forward<T: Float>(&self, &Tensor<T>) -> FerrotorchResult<Tensor<T>>`
/// and a `training: bool` field. The generated impl:
/// - forwards `Module::forward` to that inherent method,
/// - reports no parameters (plain, mutable, or named),
/// - flips `training` in `train` / `eval` and reports it from `is_training`.
macro_rules! impl_activation_module {
    ($module:ident) => {
        impl<T: Float> Module<T> for $module {
            fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
                // Inherent methods win method resolution, so this dispatches to
                // the type's own generic `forward`, not back into this impl.
                self.forward(input)
            }

            fn train(&mut self) {
                self.training = true;
            }

            fn eval(&mut self) {
                self.training = false;
            }

            fn is_training(&self) -> bool {
                self.training
            }

            fn parameters(&self) -> Vec<&Parameter<T>> {
                Vec::new()
            }

            fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
                Vec::new()
            }

            fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
                Vec::new()
            }
        }
    };
}
70
71#[derive(Debug, Clone)]
79pub struct ReLU {
80 training: bool,
81}
82
83impl ReLU {
84 pub fn new() -> Self {
86 Self { training: true }
87 }
88
89 pub fn forward<T: Float>(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
91 act::relu(input)
92 }
93}
94
95impl Default for ReLU {
96 fn default() -> Self {
97 Self::new()
98 }
99}
100
101impl_activation_module!(ReLU);
102
103#[derive(Debug, Clone)]
113pub struct Softmax2d {
114 training: bool,
115}
116
117impl Softmax2d {
118 pub fn new() -> Self {
119 Self { training: true }
120 }
121
122 pub fn forward<T: Float>(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
123 if input.ndim() != 4 {
124 return Err(ferrotorch_core::error::FerrotorchError::InvalidArgument {
125 message: format!(
126 "Softmax2d expects 4-D input [N,C,H,W], got {:?}",
127 input.shape()
128 ),
129 });
130 }
131
132 let shape = input.shape();
133 let n = shape[0];
134 let c = shape[1];
135 let h = shape[2];
136 let w = shape[3];
137
138 if input.is_cuda() {
143 if let Some(backend) = ferrotorch_core::gpu_dispatch::gpu_backend() {
144 let handle = backend.softmax2d_f32(input.gpu_handle()?, n, c, h * w)?;
145 return Tensor::from_storage(
146 ferrotorch_core::storage::TensorStorage::gpu(handle),
147 shape.to_vec(),
148 false,
149 );
150 }
151 return Err(
152 ferrotorch_core::error::FerrotorchError::NotImplementedOnCuda { op: "Softmax2d" },
153 );
154 }
155
156 let data = input.data()?;
157 let mut out = vec![<T as num_traits::Zero>::zero(); n * c * h * w];
158
159 for batch in 0..n {
161 for row in 0..h {
162 for col in 0..w {
163 let mut max_val = T::neg_infinity();
165 for ch in 0..c {
166 let idx = batch * c * h * w + ch * h * w + row * w + col;
167 if data[idx] > max_val {
168 max_val = data[idx];
169 }
170 }
171 let mut sum_exp = <T as num_traits::Zero>::zero();
173 for ch in 0..c {
174 let idx = batch * c * h * w + ch * h * w + row * w + col;
175 let e = (data[idx] - max_val).exp();
176 out[idx] = e;
177 sum_exp += e;
178 }
179 for ch in 0..c {
181 let idx = batch * c * h * w + ch * h * w + row * w + col;
182 out[idx] = out[idx] / sum_exp;
183 }
184 }
185 }
186 }
187
188 Tensor::from_storage(
189 ferrotorch_core::storage::TensorStorage::cpu(out),
190 shape.to_vec(),
191 false,
192 )
193 }
194}
195
196impl Default for Softmax2d {
197 fn default() -> Self {
198 Self::new()
199 }
200}
201
202impl_activation_module!(Softmax2d);
203
204pub use act::GeluApproximate;
209
210#[derive(Debug, Clone)]
218pub struct GELU {
219 approximate: GeluApproximate,
220 training: bool,
221}
222
223impl GELU {
224 pub fn new() -> Self {
226 Self {
227 approximate: GeluApproximate::default(),
228 training: true,
229 }
230 }
231
232 pub fn with_approximate(approximate: GeluApproximate) -> Self {
234 Self {
235 approximate,
236 training: true,
237 }
238 }
239
240 pub fn forward<T: Float>(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
242 act::gelu_with(input, self.approximate)
243 }
244}
245
246impl Default for GELU {
247 fn default() -> Self {
248 Self::new()
249 }
250}
251
252impl_activation_module!(GELU);
253
254#[derive(Debug, Clone)]
262pub struct SiLU {
263 training: bool,
264}
265
266impl SiLU {
267 pub fn new() -> Self {
269 Self { training: true }
270 }
271
272 pub fn forward<T: Float>(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
274 act::silu(input)
275 }
276}
277
278impl Default for SiLU {
279 fn default() -> Self {
280 Self::new()
281 }
282}
283
284impl_activation_module!(SiLU);
285
286#[derive(Debug, Clone)]
294pub struct Sigmoid {
295 training: bool,
296}
297
298impl Sigmoid {
299 pub fn new() -> Self {
301 Self { training: true }
302 }
303
304 pub fn forward<T: Float>(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
306 act::sigmoid(input)
307 }
308}
309
310impl Default for Sigmoid {
311 fn default() -> Self {
312 Self::new()
313 }
314}
315
316impl_activation_module!(Sigmoid);
317
318#[derive(Debug, Clone)]
324pub struct Tanh {
325 training: bool,
326}
327
328impl Tanh {
329 pub fn new() -> Self {
331 Self { training: true }
332 }
333
334 pub fn forward<T: Float>(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
336 act::tanh(input)
337 }
338}
339
340impl Default for Tanh {
341 fn default() -> Self {
342 Self::new()
343 }
344}
345
346impl_activation_module!(Tanh);
347
348#[derive(Debug, Clone)]
358pub struct Softmax {
359 pub dim: isize,
361 training: bool,
362}
363
364impl Softmax {
365 pub fn new(dim: isize) -> Self {
369 Self {
370 dim,
371 training: true,
372 }
373 }
374
375 pub fn forward<T: Float>(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
377 let ndim = input.ndim();
378 if ndim == 0 {
379 return act::softmax(input);
381 }
382
383 let axis = normalize_axis(self.dim, ndim)?;
384 if axis != ndim - 1 {
385 return Err(FerrotorchError::InvalidArgument {
386 message: format!(
387 "Softmax currently only supports dim=-1 (last axis), \
388 but got dim={} (axis={}) for a {}-D tensor",
389 self.dim, axis, ndim,
390 ),
391 });
392 }
393
394 act::softmax(input)
395 }
396}
397
398impl Default for Softmax {
399 fn default() -> Self {
400 Self::new(-1)
401 }
402}
403
404impl_activation_module!(Softmax);
405
406#[derive(Debug, Clone)]
415pub struct LogSoftmax {
416 pub dim: isize,
418 training: bool,
419}
420
421impl LogSoftmax {
422 pub fn new(dim: isize) -> Self {
424 Self {
425 dim,
426 training: true,
427 }
428 }
429
430 pub fn forward<T: Float>(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
432 let ndim = input.ndim();
433 if ndim == 0 {
434 return act::log_softmax(input);
435 }
436
437 let axis = normalize_axis(self.dim, ndim)?;
438 if axis != ndim - 1 {
439 return Err(FerrotorchError::InvalidArgument {
440 message: format!(
441 "LogSoftmax currently only supports dim=-1 (last axis), \
442 but got dim={} (axis={}) for a {}-D tensor",
443 self.dim, axis, ndim,
444 ),
445 });
446 }
447
448 act::log_softmax(input)
449 }
450}
451
452impl Default for LogSoftmax {
453 fn default() -> Self {
454 Self::new(-1)
455 }
456}
457
458impl_activation_module!(LogSoftmax);
459
460#[derive(Debug, Clone)]
475pub struct LeakyReLU {
476 pub negative_slope: f64,
478 training: bool,
479}
480
481impl LeakyReLU {
482 pub fn new(negative_slope: f64) -> Self {
484 Self {
485 negative_slope,
486 training: true,
487 }
488 }
489
490 pub fn forward<T: Float>(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
495 if (self.negative_slope - 0.0).abs() < f64::EPSILON {
496 return act::relu(input);
498 }
499 if (self.negative_slope - 1.0).abs() < f64::EPSILON {
500 return Ok(input.clone());
502 }
503
504 let relu_x = act::relu(input)?;
506
507 let scale = T::from(1.0 - self.negative_slope).unwrap();
509 let slope = T::from(self.negative_slope).unwrap();
510
511 let scale_tensor = ferrotorch_core::scalar(scale)?;
513 let slope_tensor = ferrotorch_core::scalar(slope)?;
515
516 let scaled_relu = arithmetic::mul(&relu_x, &scale_tensor)?;
518 let scaled_x = arithmetic::mul(input, &slope_tensor)?;
519 arithmetic::add(&scaled_relu, &scaled_x)
520 }
521}
522
523impl Default for LeakyReLU {
524 fn default() -> Self {
525 Self::new(0.01)
526 }
527}
528
529impl_activation_module!(LeakyReLU);
530
531#[derive(Debug, Clone)]
544pub struct ELU {
545 pub alpha: f64,
547 training: bool,
548}
549
550impl ELU {
551 pub fn new(alpha: f64) -> Self {
553 Self {
554 alpha,
555 training: true,
556 }
557 }
558
559 pub fn forward<T: Float>(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
561 act::elu(input, self.alpha)
562 }
563}
564
565impl Default for ELU {
566 fn default() -> Self {
567 Self::new(1.0)
568 }
569}
570
571impl_activation_module!(ELU);
572
573#[derive(Debug, Clone)]
585pub struct Mish {
586 training: bool,
587}
588
589impl Mish {
590 pub fn new() -> Self {
592 Self { training: true }
593 }
594
595 pub fn forward<T: Float>(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
597 act::mish(input)
598 }
599}
600
601impl Default for Mish {
602 fn default() -> Self {
603 Self::new()
604 }
605}
606
607impl_activation_module!(Mish);
608
609#[derive(Debug, Clone)]
620pub struct PReLU<T: Float> {
621 pub alpha: Parameter<T>,
623 training: bool,
624}
625
626impl<T: Float> PReLU<T> {
627 pub fn new(init_alpha: f64) -> FerrotorchResult<Self> {
629 let alpha_val = T::from(init_alpha).unwrap();
630 let alpha_tensor = ferrotorch_core::from_slice(&[alpha_val], &[1])?;
631 Ok(Self {
632 alpha: Parameter::new(alpha_tensor),
633 training: true,
634 })
635 }
636
637 pub fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
642 if self.alpha.tensor().is_cuda() {
643 return Err(
644 ferrotorch_core::error::FerrotorchError::NotImplementedOnCuda { op: "PReLU" },
645 );
646 }
647 act::prelu(input, self.alpha.tensor())
648 }
649}
650
651impl<T: Float> Module<T> for PReLU<T> {
652 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
653 self.forward(input)
654 }
655
656 fn parameters(&self) -> Vec<&Parameter<T>> {
657 vec![&self.alpha]
658 }
659
660 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
661 vec![&mut self.alpha]
662 }
663
664 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
665 vec![("alpha".to_string(), &self.alpha)]
666 }
667
668 fn train(&mut self) {
669 self.training = true;
670 }
671
672 fn eval(&mut self) {
673 self.training = false;
674 }
675
676 fn is_training(&self) -> bool {
677 self.training
678 }
679}
680
681#[derive(Debug, Clone)]
693pub struct CELU {
694 pub alpha: f64,
696 training: bool,
697}
698
699impl CELU {
700 pub fn new(alpha: f64) -> Self {
702 Self {
703 alpha,
704 training: true,
705 }
706 }
707
708 pub fn forward<T: Float>(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
710 let zero = <T as num_traits::Zero>::zero();
711 let one = <T as num_traits::One>::one();
712 let alpha = T::from(self.alpha).unwrap();
713
714 unary_map(input, |x| {
715 let pos = if x > zero { x } else { zero };
716 let neg = if x < zero {
717 alpha * ((x / alpha).exp() - one)
718 } else {
719 zero
720 };
721 pos + neg
722 })
723 }
724}
725
726impl Default for CELU {
727 fn default() -> Self {
728 Self::new(1.0)
729 }
730}
731
732impl_activation_module!(CELU);
733
734#[derive(Debug, Clone)]
749pub struct SELU {
750 training: bool,
751}
752
753const SELU_ALPHA: f64 = 1.6732632423543772;
755const SELU_LAMBDA: f64 = 1.0507009873554805;
757
758impl SELU {
759 pub fn new() -> Self {
761 Self { training: true }
762 }
763
764 pub fn forward<T: Float>(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
766 let zero = <T as num_traits::Zero>::zero();
767 let one = <T as num_traits::One>::one();
768 let alpha = T::from(SELU_ALPHA).unwrap();
769 let lambda = T::from(SELU_LAMBDA).unwrap();
770
771 unary_map(input, |x| {
772 if x > zero {
773 lambda * x
774 } else {
775 lambda * alpha * (x.exp() - one)
776 }
777 })
778 }
779}
780
781impl Default for SELU {
782 fn default() -> Self {
783 Self::new()
784 }
785}
786
787impl_activation_module!(SELU);
788
789#[derive(Debug, Clone)]
799pub struct HardSigmoid {
800 training: bool,
801}
802
803impl HardSigmoid {
804 pub fn new() -> Self {
806 Self { training: true }
807 }
808
809 pub fn forward<T: Float>(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
811 let zero = <T as num_traits::Zero>::zero();
812 let one = <T as num_traits::One>::one();
813 let three = T::from(3.0).unwrap();
814 let six = T::from(6.0).unwrap();
815
816 unary_map(input, |x| {
817 let v = (x + three) / six;
818 if v < zero {
819 zero
820 } else if v > one {
821 one
822 } else {
823 v
824 }
825 })
826 }
827}
828
829impl Default for HardSigmoid {
830 fn default() -> Self {
831 Self::new()
832 }
833}
834
835impl_activation_module!(HardSigmoid);
836
837#[derive(Debug, Clone)]
847pub struct HardSwish {
848 training: bool,
849}
850
851impl HardSwish {
852 pub fn new() -> Self {
854 Self { training: true }
855 }
856
857 pub fn forward<T: Float>(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
859 let zero = <T as num_traits::Zero>::zero();
860 let one = <T as num_traits::One>::one();
861 let three = T::from(3.0).unwrap();
862 let six = T::from(6.0).unwrap();
863
864 unary_map(input, |x| {
865 let hard_sig = {
866 let v = (x + three) / six;
867 if v < zero {
868 zero
869 } else if v > one {
870 one
871 } else {
872 v
873 }
874 };
875 x * hard_sig
876 })
877 }
878}
879
880impl Default for HardSwish {
881 fn default() -> Self {
882 Self::new()
883 }
884}
885
886impl_activation_module!(HardSwish);
887
888#[derive(Debug, Clone)]
899pub struct Softplus {
900 pub beta: f64,
902 pub threshold: f64,
905 training: bool,
906}
907
908impl Softplus {
909 pub fn new(beta: f64) -> Self {
911 Self {
912 beta,
913 threshold: 20.0,
914 training: true,
915 }
916 }
917
918 pub fn forward<T: Float>(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
922 act::softplus(input, self.beta, self.threshold)
923 }
924}
925
926impl Default for Softplus {
927 fn default() -> Self {
928 Self::new(1.0)
929 }
930}
931
932impl_activation_module!(Softplus);
933
934#[derive(Debug, Clone)]
947pub struct GLU {
948 training: bool,
949}
950
951impl GLU {
952 pub fn new() -> Self {
954 Self { training: true }
955 }
956
957 pub fn forward<T: Float>(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
962 let shape = input.shape();
963 let ndim = shape.len();
964 if ndim == 0 {
965 return Err(FerrotorchError::InvalidArgument {
966 message: "GLU requires at least 1D input".to_string(),
967 });
968 }
969
970 let last_dim = shape[ndim - 1];
971 if last_dim % 2 != 0 {
972 return Err(FerrotorchError::InvalidArgument {
973 message: format!(
974 "GLU requires the last dimension to be even, got {}",
975 last_dim
976 ),
977 });
978 }
979
980 let half = last_dim / 2;
981 let device = input.device();
982 let data = input.data_vec()?;
983
984 let outer_size: usize = shape[..ndim - 1].iter().product();
987 let outer_size = if outer_size == 0 { 1 } else { outer_size };
988
989 let one = <T as num_traits::One>::one();
990
991 let mut result = Vec::with_capacity(outer_size * half);
992 for i in 0..outer_size {
993 let base = i * last_dim;
994 for j in 0..half {
995 let a = data[base + j];
996 let b = data[base + half + j];
997 let sig_b = one / (one + (-b).exp());
998 result.push(a * sig_b);
999 }
1000 }
1001
1002 let mut out_shape = shape.to_vec();
1003 out_shape[ndim - 1] = half;
1004
1005 let out = Tensor::from_storage(
1006 ferrotorch_core::TensorStorage::cpu(result),
1007 out_shape,
1008 false,
1009 )?;
1010 if device.is_cuda() {
1011 out.to(device)
1012 } else {
1013 Ok(out)
1014 }
1015 }
1016}
1017
1018impl Default for GLU {
1019 fn default() -> Self {
1020 Self::new()
1021 }
1022}
1023
1024impl_activation_module!(GLU);
1025
1026#[derive(Debug, Clone)]
1036pub struct ReLU6 {
1037 training: bool,
1038}
1039
1040impl ReLU6 {
1041 pub fn new() -> Self {
1043 Self { training: true }
1044 }
1045
1046 pub fn forward<T: Float>(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
1048 let zero = <T as num_traits::Zero>::zero();
1049 let six = T::from(6.0).unwrap();
1050 transcendental::clamp(input, zero, six)
1051 }
1052}
1053
1054impl Default for ReLU6 {
1055 fn default() -> Self {
1056 Self::new()
1057 }
1058}
1059
1060impl_activation_module!(ReLU6);
1061
1062#[derive(Debug, Clone)]
1076pub struct Hardtanh {
1077 pub min_val: f64,
1079 pub max_val: f64,
1081 training: bool,
1082}
1083
1084impl Hardtanh {
1085 pub fn new(min_val: f64, max_val: f64) -> Self {
1087 Self {
1088 min_val,
1089 max_val,
1090 training: true,
1091 }
1092 }
1093
1094 pub fn forward<T: Float>(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
1096 let min = T::from(self.min_val).unwrap();
1097 let max = T::from(self.max_val).unwrap();
1098 transcendental::clamp(input, min, max)
1099 }
1100}
1101
1102impl Default for Hardtanh {
1103 fn default() -> Self {
1104 Self::new(-1.0, 1.0)
1105 }
1106}
1107
1108impl_activation_module!(Hardtanh);
1109
1110#[derive(Debug, Clone)]
1120pub struct LogSigmoid {
1121 training: bool,
1122}
1123
1124impl LogSigmoid {
1125 pub fn new() -> Self {
1127 Self { training: true }
1128 }
1129
1130 pub fn forward<T: Float>(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
1135 let neg_input = arithmetic::neg(input)?;
1137 let sp = act::softplus(&neg_input, 1.0, 20.0)?;
1138 arithmetic::neg(&sp)
1139 }
1140}
1141
1142impl Default for LogSigmoid {
1143 fn default() -> Self {
1144 Self::new()
1145 }
1146}
1147
1148impl_activation_module!(LogSigmoid);
1149
1150#[derive(Debug, Clone)]
1159pub struct Softmin {
1160 pub dim: isize,
1162 training: bool,
1163}
1164
1165impl Softmin {
1166 pub fn new(dim: isize) -> Self {
1168 Self {
1169 dim,
1170 training: true,
1171 }
1172 }
1173
1174 pub fn forward<T: Float>(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
1176 let ndim = input.ndim();
1177 if ndim == 0 {
1178 let neg_input = arithmetic::neg(input)?;
1179 return act::softmax(&neg_input);
1180 }
1181
1182 let axis = normalize_axis(self.dim, ndim)?;
1183 if axis != ndim - 1 {
1184 return Err(FerrotorchError::InvalidArgument {
1185 message: format!(
1186 "Softmin currently only supports dim=-1 (last axis), \
1187 but got dim={} (axis={}) for a {}-D tensor",
1188 self.dim, axis, ndim,
1189 ),
1190 });
1191 }
1192
1193 let neg_input = arithmetic::neg(input)?;
1194 act::softmax(&neg_input)
1195 }
1196}
1197
1198impl Default for Softmin {
1199 fn default() -> Self {
1200 Self::new(-1)
1201 }
1202}
1203
1204impl_activation_module!(Softmin);
1205
1206#[derive(Debug, Clone)]
1219pub struct Threshold {
1220 pub threshold: f64,
1222 pub value: f64,
1224 training: bool,
1225}
1226
1227impl Threshold {
1228 pub fn new(threshold: f64, value: f64) -> Self {
1230 Self {
1231 threshold,
1232 value,
1233 training: true,
1234 }
1235 }
1236
1237 pub fn forward<T: Float>(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
1239 let thresh = T::from(self.threshold).unwrap();
1240 let val = T::from(self.value).unwrap();
1241 unary_map(input, |x| if x > thresh { x } else { val })
1242 }
1243}
1244
1245impl_activation_module!(Threshold);
1246
1247#[derive(Debug, Clone)]
1261pub struct Softshrink {
1262 pub lambda: f64,
1264 training: bool,
1265}
1266
1267impl Softshrink {
1268 pub fn new(lambda: f64) -> Self {
1270 Self {
1271 lambda,
1272 training: true,
1273 }
1274 }
1275
1276 pub fn forward<T: Float>(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
1278 let lam = T::from(self.lambda).unwrap();
1279 let neg_lam = T::from(-self.lambda).unwrap();
1280 let zero = <T as num_traits::Zero>::zero();
1281 unary_map(input, |x| {
1282 if x > lam {
1283 x - lam
1284 } else if x < neg_lam {
1285 x + lam
1286 } else {
1287 zero
1288 }
1289 })
1290 }
1291}
1292
1293impl Default for Softshrink {
1294 fn default() -> Self {
1295 Self::new(0.5)
1296 }
1297}
1298
1299impl_activation_module!(Softshrink);
1300
1301#[derive(Debug, Clone)]
1314pub struct Hardshrink {
1315 pub lambda: f64,
1317 training: bool,
1318}
1319
1320impl Hardshrink {
1321 pub fn new(lambda: f64) -> Self {
1323 Self {
1324 lambda,
1325 training: true,
1326 }
1327 }
1328
1329 pub fn forward<T: Float>(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
1331 let lam = T::from(self.lambda).unwrap();
1332 let neg_lam = T::from(-self.lambda).unwrap();
1333 let zero = <T as num_traits::Zero>::zero();
1334 unary_map(input, |x| if x > lam || x < neg_lam { x } else { zero })
1335 }
1336}
1337
1338impl Default for Hardshrink {
1339 fn default() -> Self {
1340 Self::new(0.5)
1341 }
1342}
1343
1344impl_activation_module!(Hardshrink);
1345
1346#[derive(Debug, Clone)]
1354pub struct Tanhshrink {
1355 training: bool,
1356}
1357
1358impl Tanhshrink {
1359 pub fn new() -> Self {
1361 Self { training: true }
1362 }
1363
1364 pub fn forward<T: Float>(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
1366 let tanh_x = act::tanh(input)?;
1367 arithmetic::sub(input, &tanh_x)
1368 }
1369}
1370
1371impl Default for Tanhshrink {
1372 fn default() -> Self {
1373 Self::new()
1374 }
1375}
1376
1377impl_activation_module!(Tanhshrink);
1378
1379#[derive(Debug, Clone)]
1387pub struct Softsign {
1388 training: bool,
1389}
1390
1391impl Softsign {
1392 pub fn new() -> Self {
1394 Self { training: true }
1395 }
1396
1397 pub fn forward<T: Float>(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
1399 let one = <T as num_traits::One>::one();
1400 unary_map(input, |x| x / (one + x.abs()))
1401 }
1402}
1403
1404impl Default for Softsign {
1405 fn default() -> Self {
1406 Self::new()
1407 }
1408}
1409
1410impl_activation_module!(Softsign);
1411
/// Randomized leaky ReLU.
///
/// Non-negative inputs pass through. In training mode, each negative element is
/// multiplied by its own pseudo-random slope in `[lower, upper]`. In eval mode,
/// every negative element uses the fixed slope `(lower + upper) / 2`.
#[derive(Debug, Clone)]
pub struct RReLU {
    /// Lower bound of the sampled negative slope.
    pub lower: f64,
    /// Upper bound of the sampled negative slope.
    pub upper: f64,
    // Selects random (training) vs. mean-slope (eval) behavior in `forward`.
    training: bool,
}
1437
/// Produces a non-zero seed for the xorshift generator used by [`RReLU`] in
/// training mode.
///
/// Hashes the wall-clock time, the current thread id, and a process-wide call
/// counter with `DefaultHasher`. The counter ensures that two calls falling on
/// the same `SystemTime` tick (whose resolution is platform-dependent) still
/// get different seeds. Without it, back-to-back forward passes on one thread
/// could replay identical "random" slopes. A zero hash is remapped because a
/// xorshift state of zero never changes.
fn rrelu_xorshift_seed() -> u64 {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::SystemTime;

    // Only uniqueness of each `fetch_add` result matters, which every ordering
    // guarantees, so `Relaxed` is sufficient. `AtomicUsize` is available on all
    // targets, unlike `AtomicU64`.
    static CALLS: AtomicUsize = AtomicUsize::new(0);

    let mut hasher = DefaultHasher::new();
    SystemTime::now().hash(&mut hasher);
    std::thread::current().id().hash(&mut hasher);
    CALLS.fetch_add(1, Ordering::Relaxed).hash(&mut hasher);

    match hasher.finish() {
        0 => 0xdeadbeefcafe,
        state => state,
    }
}
1453
/// Advances a xorshift64 state (shift triple 13 / 7 / 17) in place and returns
/// the new state divided by `u64::MAX`, a value in `[0, 1]`.
///
/// The state must be non-zero: a zero state maps to zero forever.
#[inline]
fn rrelu_xorshift_next(state: &mut u64) -> f64 {
    let mut x = *state;
    x ^= x << 13;
    x ^= x >> 7;
    x ^= x << 17;
    *state = x;
    x as f64 / u64::MAX as f64
}
1462
1463impl RReLU {
1464 pub fn new(lower: f64, upper: f64) -> Self {
1466 Self {
1467 lower,
1468 upper,
1469 training: true,
1470 }
1471 }
1472
1473 pub fn forward<T: Float>(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
1475 let zero = <T as num_traits::Zero>::zero();
1476
1477 if self.training {
1478 let rng_state = std::cell::Cell::new(rrelu_xorshift_seed());
1481 let lower = self.lower;
1482 let upper = self.upper;
1483 let range = upper - lower;
1484
1485 unary_map(input, |x| {
1486 if x >= zero {
1487 x
1488 } else {
1489 let mut st = rng_state.get();
1490 let u = rrelu_xorshift_next(&mut st);
1491 rng_state.set(st);
1492 let slope = T::from(lower + u * range).unwrap();
1493 slope * x
1494 }
1495 })
1496 } else {
1497 let mean_slope = T::from((self.lower + self.upper) / 2.0).unwrap();
1499 unary_map(input, |x| if x >= zero { x } else { mean_slope * x })
1500 }
1501 }
1502}
1503
1504impl Default for RReLU {
1505 fn default() -> Self {
1506 Self::new(1.0 / 8.0, 1.0 / 3.0)
1507 }
1508}
1509
1510impl_activation_module!(RReLU);
1511
1512#[cfg(test)]
1517mod tests {
1518 use super::*;
1519 use ferrotorch_core::TensorStorage;
1520
1521 fn t(data: &[f64]) -> Tensor<f64> {
1523 Tensor::from_storage(TensorStorage::cpu(data.to_vec()), vec![data.len()], false).unwrap()
1524 }
1525
1526 fn t2d(data: &[f64], rows: usize, cols: usize) -> Tensor<f64> {
1528 Tensor::from_storage(TensorStorage::cpu(data.to_vec()), vec![rows, cols], false).unwrap()
1529 }
1530
1531 fn assert_zero_param_module<M, T: Float>(module: &mut M)
1537 where
1538 M: Module<T>,
1539 {
1540 assert!(module.parameters().is_empty(), "should have no parameters");
1541 assert!(
1542 module.parameters_mut().is_empty(),
1543 "should have no mutable parameters"
1544 );
1545 assert!(
1546 module.named_parameters().is_empty(),
1547 "should have no named parameters"
1548 );
1549 assert!(module.is_training(), "default should be training mode");
1550 module.eval();
1551 assert!(!module.is_training(), "eval() should set training=false");
1552 module.train();
1553 assert!(module.is_training(), "train() should set training=true");
1554 }
1555
1556 #[test]
1561 fn test_relu_forward() {
1562 let m = ReLU::new();
1563 let x = t(&[-2.0, -1.0, 0.0, 1.0, 2.0]);
1564 let y = m.forward(&x).unwrap();
1565 let d = y.data().unwrap();
1566 assert!((d[0] - 0.0).abs() < 1e-7);
1567 assert!((d[1] - 0.0).abs() < 1e-7);
1568 assert!((d[2] - 0.0).abs() < 1e-7);
1569 assert!((d[3] - 1.0).abs() < 1e-7);
1570 assert!((d[4] - 2.0).abs() < 1e-7);
1571 }
1572
1573 #[test]
1574 fn test_relu_module_trait() {
1575 let mut m = ReLU::new();
1576 assert_zero_param_module::<ReLU, f64>(&mut m);
1577 }
1578
1579 #[test]
1584 fn test_gelu_forward() {
1585 let m = GELU::new();
1586 let x = t(&[0.0]);
1588 let y = m.forward(&x).unwrap();
1589 assert!(y.data().unwrap()[0].abs() < 1e-7);
1590
1591 let x = t(&[1.0, 2.0]);
1593 let y = m.forward(&x).unwrap();
1594 let d = y.data().unwrap();
1595 assert!(d[0] > 0.0);
1596 assert!(d[1] > 0.0);
1597
1598 let x = t(&[10.0]);
1600 let y = m.forward(&x).unwrap();
1601 assert!((y.data().unwrap()[0] - 10.0).abs() < 0.01);
1602 }
1603
1604 #[test]
1605 fn test_gelu_module_trait() {
1606 let mut m = GELU::new();
1607 assert_zero_param_module::<GELU, f64>(&mut m);
1608 }
1609
1610 #[test]
1615 fn test_silu_forward() {
1616 let m = SiLU::new();
1617 let x = t(&[0.0]);
1619 let y = m.forward(&x).unwrap();
1620 assert!(y.data().unwrap()[0].abs() < 1e-7);
1621
1622 let x = t(&[10.0]);
1624 let y = m.forward(&x).unwrap();
1625 assert!((y.data().unwrap()[0] - 10.0).abs() < 0.01);
1626 }
1627
1628 #[test]
1629 fn test_silu_module_trait() {
1630 let mut m = SiLU::new();
1631 assert_zero_param_module::<SiLU, f64>(&mut m);
1632 }
1633
1634 #[test]
1639 fn test_sigmoid_forward() {
1640 let m = Sigmoid::new();
1641 let x = t(&[0.0]);
1642 let y = m.forward(&x).unwrap();
1643 assert!((y.data().unwrap()[0] - 0.5).abs() < 1e-7);
1644
1645 let x = t(&[-100.0, 100.0]);
1646 let y = m.forward(&x).unwrap();
1647 let d = y.data().unwrap();
1648 assert!(d[0] < 1e-10, "sigmoid(-100) should be ~0");
1649 assert!((d[1] - 1.0).abs() < 1e-10, "sigmoid(100) should be ~1");
1650 }
1651
1652 #[test]
1653 fn test_sigmoid_module_trait() {
1654 let mut m = Sigmoid::new();
1655 assert_zero_param_module::<Sigmoid, f64>(&mut m);
1656 }
1657
1658 #[test]
1663 fn test_tanh_forward() {
1664 let m = Tanh::new();
1665 let x = t(&[0.0]);
1666 let y = m.forward(&x).unwrap();
1667 assert!(y.data().unwrap()[0].abs() < 1e-7);
1668
1669 let x = t(&[-100.0, 100.0]);
1670 let y = m.forward(&x).unwrap();
1671 let d = y.data().unwrap();
1672 assert!((d[0] + 1.0).abs() < 1e-10, "tanh(-100) should be ~-1");
1673 assert!((d[1] - 1.0).abs() < 1e-10, "tanh(100) should be ~1");
1674 }
1675
1676 #[test]
1677 fn test_tanh_module_trait() {
1678 let mut m = Tanh::new();
1679 assert_zero_param_module::<Tanh, f64>(&mut m);
1680 }
1681
1682 #[test]
1687 fn test_softmax_forward_1d() {
1688 let m = Softmax::new(-1);
1689 let x = t(&[1.0, 2.0, 3.0]);
1690 let y = m.forward(&x).unwrap();
1691 let d = y.data().unwrap();
1692
1693 let total: f64 = d.iter().sum();
1695 assert!((total - 1.0).abs() < 1e-7);
1696
1697 assert!(d[0] < d[1]);
1699 assert!(d[1] < d[2]);
1700 }
1701
1702 #[test]
1703 fn test_softmax_forward_2d() {
1704 let m = Softmax::new(-1);
1705 let x = t2d(&[1.0, 2.0, 3.0, 4.0], 2, 2);
1707 let y = m.forward(&x).unwrap();
1708 let d = y.data().unwrap();
1709
1710 let row0_sum = d[0] + d[1];
1712 let row1_sum = d[2] + d[3];
1713 assert!((row0_sum - 1.0).abs() < 1e-7);
1714 assert!((row1_sum - 1.0).abs() < 1e-7);
1715 }
1716
1717 #[test]
1718 fn test_softmax_wrong_dim() {
1719 let m = Softmax::new(0);
1720 let x = t2d(&[1.0, 2.0, 3.0, 4.0], 2, 2);
1721 assert!(m.forward(&x).is_err());
1723 }
1724
1725 #[test]
1726 fn test_softmax_module_trait() {
1727 let mut m = Softmax::new(-1);
1728 assert_zero_param_module::<Softmax, f64>(&mut m);
1729 }
1730
1731 #[test]
1736 fn test_log_softmax_forward_1d() {
1737 let m = LogSoftmax::new(-1);
1738 let x = t(&[1.0, 2.0, 3.0]);
1739 let y = m.forward(&x).unwrap();
1740 let d = y.data().unwrap();
1741
1742 let total: f64 = d.iter().map(|&v| v.exp()).sum();
1744 assert!((total - 1.0).abs() < 1e-7, "exp(log_softmax) sum = {total}");
1745
1746 assert!(d.iter().all(|&v| v <= 0.0));
1748 }
1749
1750 #[test]
1751 fn test_log_softmax_module_trait() {
1752 let mut m = LogSoftmax::new(-1);
1753 assert_zero_param_module::<LogSoftmax, f64>(&mut m);
1754 }
1755
1756 #[test]
1761 fn test_leaky_relu_forward() {
1762 let m = LeakyReLU::new(0.01);
1763 let x = t(&[-2.0, -1.0, 0.0, 1.0, 2.0]);
1764 let y = m.forward(&x).unwrap();
1765 let d = y.data().unwrap();
1766
1767 assert!((d[0] - (-0.02)).abs() < 1e-7, "LeakyReLU(-2) = {}", d[0]);
1768 assert!((d[1] - (-0.01)).abs() < 1e-7, "LeakyReLU(-1) = {}", d[1]);
1769 assert!((d[2] - 0.0).abs() < 1e-7, "LeakyReLU(0) = {}", d[2]);
1770 assert!((d[3] - 1.0).abs() < 1e-7, "LeakyReLU(1) = {}", d[3]);
1771 assert!((d[4] - 2.0).abs() < 1e-7, "LeakyReLU(2) = {}", d[4]);
1772 }
1773
1774 #[test]
1775 fn test_leaky_relu_large_slope() {
1776 let m = LeakyReLU::new(0.2);
1777 let x = t(&[-5.0, 3.0]);
1778 let y = m.forward(&x).unwrap();
1779 let d = y.data().unwrap();
1780
1781 assert!(
1782 (d[0] - (-1.0)).abs() < 1e-7,
1783 "LeakyReLU(-5, slope=0.2) = {}",
1784 d[0]
1785 );
1786 assert!(
1787 (d[1] - 3.0).abs() < 1e-7,
1788 "LeakyReLU(3, slope=0.2) = {}",
1789 d[1]
1790 );
1791 }
1792
1793 #[test]
1794 fn test_leaky_relu_zero_slope_is_relu() {
1795 let m = LeakyReLU::new(0.0);
1796 let x = t(&[-2.0, 0.0, 3.0]);
1797 let y = m.forward(&x).unwrap();
1798 let d = y.data().unwrap();
1799
1800 assert!((d[0] - 0.0).abs() < 1e-7);
1801 assert!((d[1] - 0.0).abs() < 1e-7);
1802 assert!((d[2] - 3.0).abs() < 1e-7);
1803 }
1804
1805 #[test]
1806 fn test_leaky_relu_module_trait() {
1807 let mut m = LeakyReLU::new(0.01);
1808 assert_zero_param_module::<LeakyReLU, f64>(&mut m);
1809 }
1810
1811 #[test]
1816 fn test_elu_forward() {
1817 let m = ELU::new(1.0);
1818 let x = t(&[-2.0, -1.0, 0.0, 1.0, 2.0]);
1819 let y = m.forward(&x).unwrap();
1820 let d = y.data().unwrap();
1821
1822 assert!((d[3] - 1.0).abs() < 1e-7);
1824 assert!((d[4] - 2.0).abs() < 1e-7);
1825
1826 assert!((d[2] - 0.0).abs() < 1e-7);
1828
1829 let expected_m1 = 1.0 * ((-1.0_f64).exp() - 1.0);
1831 assert!(
1832 (d[1] - expected_m1).abs() < 1e-7,
1833 "ELU(-1) expected {}, got {}",
1834 expected_m1,
1835 d[1]
1836 );
1837
1838 let expected_m2 = 1.0 * ((-2.0_f64).exp() - 1.0);
1839 assert!(
1840 (d[0] - expected_m2).abs() < 1e-7,
1841 "ELU(-2) expected {}, got {}",
1842 expected_m2,
1843 d[0]
1844 );
1845
1846 let x = t(&[-100.0]);
1848 let y = m.forward(&x).unwrap();
1849 assert!((y.data().unwrap()[0] + 1.0).abs() < 1e-7);
1850 }
1851
1852 #[test]
1853 fn test_elu_custom_alpha() {
1854 let m = ELU::new(2.0);
1855 let x = t(&[-1.0, 1.0]);
1856 let y = m.forward(&x).unwrap();
1857 let d = y.data().unwrap();
1858
1859 let expected = 2.0 * ((-1.0_f64).exp() - 1.0);
1860 assert!((d[0] - expected).abs() < 1e-7);
1861 assert!((d[1] - 1.0).abs() < 1e-7);
1862 }
1863
1864 #[test]
1865 fn test_elu_module_trait() {
1866 let mut m = ELU::new(1.0);
1867 assert_zero_param_module::<ELU, f64>(&mut m);
1868 }
1869
1870 #[test]
1875 fn test_mish_forward() {
1876 let m = Mish::new();
1877 let x = t(&[0.0]);
1879 let y = m.forward(&x).unwrap();
1880 assert!(y.data().unwrap()[0].abs() < 1e-7, "mish(0) should be 0");
1881
1882 let x = t(&[20.0]);
1884 let y = m.forward(&x).unwrap();
1885 assert!(
1886 (y.data().unwrap()[0] - 20.0).abs() < 0.01,
1887 "mish(20) should be ~20"
1888 );
1889
1890 let x = t(&[-1.0]);
1892 let y = m.forward(&x).unwrap();
1893 let val = y.data().unwrap()[0];
1894 let softplus = (1.0 + (-1.0_f64).exp()).ln();
1895 let expected = -softplus.tanh();
1896 assert!(
1897 (val - expected).abs() < 1e-7,
1898 "mish(-1) expected {}, got {}",
1899 expected,
1900 val
1901 );
1902 }
1903
1904 #[test]
1905 fn test_mish_module_trait() {
1906 let mut m = Mish::new();
1907 assert_zero_param_module::<Mish, f64>(&mut m);
1908 }
1909
1910 #[test]
1919 fn test_prelu_forward_default() {
1920 let m = PReLU::<f64>::new(0.25).unwrap();
1921 let x = t(&[-2.0, -1.0, 0.0, 1.0, 2.0]);
1922 let y = m.forward(&x).unwrap();
1923 let d = y.data().unwrap();
1924 assert!((d[0] - (-0.5)).abs() < 1e-6, "PReLU(-2) = {}", d[0]);
1926 assert!((d[1] - (-0.25)).abs() < 1e-6, "PReLU(-1) = {}", d[1]);
1927 assert!((d[2] - 0.0).abs() < 1e-6, "PReLU(0) = {}", d[2]);
1928 assert!((d[3] - 1.0).abs() < 1e-6, "PReLU(1) = {}", d[3]);
1929 assert!((d[4] - 2.0).abs() < 1e-6, "PReLU(2) = {}", d[4]);
1930 }
1931
1932 #[test]
1933 fn test_prelu_has_parameter() {
1934 let m = PReLU::<f64>::new(0.25).unwrap();
1935 assert_eq!(m.parameters().len(), 1, "PReLU should have 1 parameter");
1936 let named = m.named_parameters();
1937 assert_eq!(named.len(), 1);
1938 assert_eq!(named[0].0, "alpha");
1939 }
1940
1941 #[test]
1942 fn test_prelu_module_trait() {
1943 let mut m = PReLU::<f64>::new(0.25).unwrap();
1944 assert_eq!(m.parameters().len(), 1);
1945 assert!(m.is_training());
1946 m.eval();
1947 assert!(!m.is_training());
1948 m.train();
1949 assert!(m.is_training());
1950 }
1951
1952 #[test]
1957 fn test_celu_forward() {
1958 let m = CELU::new(1.0);
1959 let x = t(&[-2.0, -1.0, 0.0, 1.0, 2.0]);
1960 let y = m.forward(&x).unwrap();
1961 let d = y.data().unwrap();
1962
1963 assert!((d[3] - 1.0).abs() < 1e-7);
1965 assert!((d[4] - 2.0).abs() < 1e-7);
1966 assert!((d[2] - 0.0).abs() < 1e-7);
1967
1968 let expected_m1 = 1.0 * ((-1.0_f64).exp() - 1.0);
1970 assert!((d[1] - expected_m1).abs() < 1e-7, "CELU(-1) = {}", d[1]);
1971 }
1972
1973 #[test]
1974 fn test_celu_module_trait() {
1975 let mut m = CELU::new(1.0);
1976 assert_zero_param_module::<CELU, f64>(&mut m);
1977 }
1978
1979 #[test]
1984 fn test_selu_forward() {
1985 let m = SELU::new();
1986 let x = t(&[-1.0, 0.0, 1.0]);
1987 let y = m.forward(&x).unwrap();
1988 let d = y.data().unwrap();
1989
1990 let lambda = 1.0507009873554805_f64;
1992 let alpha = 1.6732632423543772_f64;
1993 assert!((d[2] - lambda * 1.0).abs() < 1e-7, "SELU(1) = {}", d[2]);
1994 assert!((d[1] - 0.0).abs() < 1e-7, "SELU(0) = {}", d[1]);
1995
1996 let expected_m1 = lambda * alpha * ((-1.0_f64).exp() - 1.0);
1998 assert!((d[0] - expected_m1).abs() < 1e-7, "SELU(-1) = {}", d[0]);
1999 }
2000
2001 #[test]
2002 fn test_selu_module_trait() {
2003 let mut m = SELU::new();
2004 assert_zero_param_module::<SELU, f64>(&mut m);
2005 }
2006
2007 #[test]
2012 fn test_hard_sigmoid_forward() {
2013 let m = HardSigmoid::new();
2014 let x = t(&[-4.0, -3.0, 0.0, 3.0, 5.0]);
2021 let y = m.forward(&x).unwrap();
2022 let d = y.data().unwrap();
2023 assert!((d[0] - 0.0).abs() < 1e-7, "HardSigmoid(-4) = {}", d[0]);
2024 assert!((d[1] - 0.0).abs() < 1e-7, "HardSigmoid(-3) = {}", d[1]);
2025 assert!((d[2] - 0.5).abs() < 1e-7, "HardSigmoid(0) = {}", d[2]);
2026 assert!((d[3] - 1.0).abs() < 1e-7, "HardSigmoid(3) = {}", d[3]);
2027 assert!((d[4] - 1.0).abs() < 1e-7, "HardSigmoid(5) = {}", d[4]);
2028 }
2029
2030 #[test]
2031 fn test_hard_sigmoid_module_trait() {
2032 let mut m = HardSigmoid::new();
2033 assert_zero_param_module::<HardSigmoid, f64>(&mut m);
2034 }
2035
2036 #[test]
2041 fn test_hard_swish_forward() {
2042 let m = HardSwish::new();
2043 let x = t(&[-4.0, 0.0, 3.0, 5.0, -1.0]);
2050 let y = m.forward(&x).unwrap();
2051 let d = y.data().unwrap();
2052 assert!((d[0] - 0.0).abs() < 1e-7, "HardSwish(-4) = {}", d[0]);
2053 assert!((d[1] - 0.0).abs() < 1e-7, "HardSwish(0) = {}", d[1]);
2054 assert!((d[2] - 3.0).abs() < 1e-7, "HardSwish(3) = {}", d[2]);
2055 assert!((d[3] - 5.0).abs() < 1e-7, "HardSwish(5) = {}", d[3]);
2056 assert!(
2057 (d[4] - (-1.0 / 3.0)).abs() < 1e-7,
2058 "HardSwish(-1) = {}",
2059 d[4]
2060 );
2061 }
2062
2063 #[test]
2064 fn test_hard_swish_module_trait() {
2065 let mut m = HardSwish::new();
2066 assert_zero_param_module::<HardSwish, f64>(&mut m);
2067 }
2068
2069 #[test]
2074 fn test_softplus_forward() {
2075 let m = Softplus::new(1.0);
2076 let x = t(&[0.0]);
2078 let y = m.forward(&x).unwrap();
2079 let d = y.data().unwrap();
2080 assert!((d[0] - 2.0_f64.ln()).abs() < 1e-7, "Softplus(0) = {}", d[0]);
2081
2082 let x = t(&[25.0]);
2084 let y = m.forward(&x).unwrap();
2085 let d = y.data().unwrap();
2086 assert!((d[0] - 25.0).abs() < 1e-5, "Softplus(25) = {}", d[0]);
2087
2088 let x = t(&[1.0]);
2090 let y = m.forward(&x).unwrap();
2091 let d = y.data().unwrap();
2092 let expected = (1.0 + 1.0_f64.exp()).ln();
2093 assert!((d[0] - expected).abs() < 1e-7, "Softplus(1) = {}", d[0]);
2094 }
2095
2096 #[test]
2097 fn test_softplus_custom_beta() {
2098 let m = Softplus::new(2.0);
2099 let x = t(&[0.0]);
2101 let y = m.forward(&x).unwrap();
2102 let d = y.data().unwrap();
2103 let expected = 2.0_f64.ln() / 2.0;
2104 assert!(
2105 (d[0] - expected).abs() < 1e-7,
2106 "Softplus(0, beta=2) = {}",
2107 d[0]
2108 );
2109 }
2110
2111 #[test]
2112 fn test_softplus_module_trait() {
2113 let mut m = Softplus::new(1.0);
2114 assert_zero_param_module::<Softplus, f64>(&mut m);
2115 }
2116
2117 #[test]
2122 fn test_glu_forward_1d() {
2123 let m = GLU::new();
2124 let x = t(&[1.0, 0.0, 2.0, 0.0]);
2128 let y = m.forward(&x).unwrap();
2129 assert_eq!(y.shape(), &[2]);
2130 let d = y.data().unwrap();
2131 let sig_2 = 1.0 / (1.0 + (-2.0_f64).exp());
2132 assert!((d[0] - sig_2).abs() < 1e-7, "GLU[0] = {}", d[0]);
2133 assert!((d[1] - 0.0).abs() < 1e-7, "GLU[1] = {}", d[1]);
2134 }
2135
2136 #[test]
2137 fn test_glu_forward_2d() {
2138 let m = GLU::new();
2139 let x = t2d(&[1.0, 0.0, 2.0, 0.0], 1, 4);
2141 let y = m.forward(&x).unwrap();
2142 assert_eq!(y.shape(), &[1, 2]);
2143 let d = y.data().unwrap();
2144 let sig_2 = 1.0 / (1.0 + (-2.0_f64).exp());
2145 assert!((d[0] - sig_2).abs() < 1e-7);
2146 assert!((d[1] - 0.0).abs() < 1e-7);
2147 }
2148
2149 #[test]
2150 fn test_glu_odd_dim_error() {
2151 let m = GLU::new();
2152 let x = t(&[1.0, 2.0, 3.0]); assert!(m.forward(&x).is_err());
2154 }
2155
2156 #[test]
2157 fn test_glu_module_trait() {
2158 let mut m = GLU::new();
2159 assert_zero_param_module::<GLU, f64>(&mut m);
2160 }
2161
2162 #[test]
2167 fn test_relu6_forward() {
2168 let m = ReLU6::new();
2169 let x = t(&[-2.0, 0.0, 3.0, 6.0, 10.0]);
2170 let y = m.forward(&x).unwrap();
2171 let d = y.data().unwrap();
2172 assert!((d[0] - 0.0).abs() < 1e-7, "ReLU6(-2) = {}", d[0]);
2173 assert!((d[1] - 0.0).abs() < 1e-7, "ReLU6(0) = {}", d[1]);
2174 assert!((d[2] - 3.0).abs() < 1e-7, "ReLU6(3) = {}", d[2]);
2175 assert!((d[3] - 6.0).abs() < 1e-7, "ReLU6(6) = {}", d[3]);
2176 assert!((d[4] - 6.0).abs() < 1e-7, "ReLU6(10) = {}", d[4]);
2177 }
2178
2179 #[test]
2180 fn test_relu6_module_trait() {
2181 let mut m = ReLU6::new();
2182 assert_zero_param_module::<ReLU6, f64>(&mut m);
2183 }
2184
2185 #[test]
2190 fn test_hardtanh_forward_default() {
2191 let m = Hardtanh::default();
2192 let x = t(&[-5.0, -1.0, 0.0, 0.5, 1.0, 3.0]);
2194 let y = m.forward(&x).unwrap();
2195 let d = y.data().unwrap();
2196 assert!((d[0] - (-1.0)).abs() < 1e-7, "Hardtanh(-5) = {}", d[0]);
2197 assert!((d[1] - (-1.0)).abs() < 1e-7, "Hardtanh(-1) = {}", d[1]);
2198 assert!((d[2] - 0.0).abs() < 1e-7, "Hardtanh(0) = {}", d[2]);
2199 assert!((d[3] - 0.5).abs() < 1e-7, "Hardtanh(0.5) = {}", d[3]);
2200 assert!((d[4] - 1.0).abs() < 1e-7, "Hardtanh(1) = {}", d[4]);
2201 assert!((d[5] - 1.0).abs() < 1e-7, "Hardtanh(3) = {}", d[5]);
2202 }
2203
2204 #[test]
2205 fn test_hardtanh_custom_range() {
2206 let m = Hardtanh::new(-2.0, 2.0);
2207 let x = t(&[-5.0, -2.0, 0.0, 2.0, 5.0]);
2208 let y = m.forward(&x).unwrap();
2209 let d = y.data().unwrap();
2210 assert!((d[0] - (-2.0)).abs() < 1e-7);
2211 assert!((d[1] - (-2.0)).abs() < 1e-7);
2212 assert!((d[2] - 0.0).abs() < 1e-7);
2213 assert!((d[3] - 2.0).abs() < 1e-7);
2214 assert!((d[4] - 2.0).abs() < 1e-7);
2215 }
2216
2217 #[test]
2218 fn test_hardtanh_module_trait() {
2219 let mut m = Hardtanh::default();
2220 assert_zero_param_module::<Hardtanh, f64>(&mut m);
2221 }
2222
2223 #[test]
2228 fn test_log_sigmoid_forward() {
2229 let m = LogSigmoid::new();
2230 let x = t(&[0.0]);
2232 let y = m.forward(&x).unwrap();
2233 let d = y.data().unwrap();
2234 assert!(
2235 (d[0] - (-2.0_f64.ln())).abs() < 1e-6,
2236 "LogSigmoid(0) = {}, expected {}",
2237 d[0],
2238 -2.0_f64.ln()
2239 );
2240
2241 let x = t(&[-10.0, -1.0, 0.0, 1.0, 10.0]);
2243 let y = m.forward(&x).unwrap();
2244 let d = y.data().unwrap();
2245 assert!(
2246 d.iter().all(|&v| v <= 0.0),
2247 "All LogSigmoid values should be <= 0"
2248 );
2249
2250 assert!(
2252 d[4].abs() < 1e-4,
2253 "LogSigmoid(10) should be ~0, got {}",
2254 d[4]
2255 );
2256
2257 assert!(
2259 (d[0] - (-10.0)).abs() < 0.1,
2260 "LogSigmoid(-10) should be ~-10, got {}",
2261 d[0]
2262 );
2263 }
2264
2265 #[test]
2266 fn test_log_sigmoid_module_trait() {
2267 let mut m = LogSigmoid::new();
2268 assert_zero_param_module::<LogSigmoid, f64>(&mut m);
2269 }
2270
2271 #[test]
2276 fn test_softmin_forward_1d() {
2277 let m = Softmin::new(-1);
2278 let x = t(&[1.0, 2.0, 3.0]);
2279 let y = m.forward(&x).unwrap();
2280 let d = y.data().unwrap();
2281
2282 let total: f64 = d.iter().sum();
2284 assert!((total - 1.0).abs() < 1e-7, "Softmin sum = {}", total);
2285
2286 assert!(d[0] > d[1], "softmin(1) > softmin(2)");
2288 assert!(d[1] > d[2], "softmin(2) > softmin(3)");
2289 }
2290
2291 #[test]
2292 fn test_softmin_wrong_dim() {
2293 let m = Softmin::new(0);
2294 let x = t2d(&[1.0, 2.0, 3.0, 4.0], 2, 2);
2295 assert!(m.forward(&x).is_err());
2296 }
2297
2298 #[test]
2299 fn test_softmin_module_trait() {
2300 let mut m = Softmin::new(-1);
2301 assert_zero_param_module::<Softmin, f64>(&mut m);
2302 }
2303
2304 #[test]
2309 fn test_threshold_forward() {
2310 let m = Threshold::new(0.5, -1.0);
2311 let x = t(&[-1.0, 0.0, 0.5, 1.0, 2.0]);
2312 let y = m.forward(&x).unwrap();
2313 let d = y.data().unwrap();
2314 assert!((d[0] - (-1.0)).abs() < 1e-7, "Threshold(-1) = {}", d[0]);
2316 assert!((d[1] - (-1.0)).abs() < 1e-7, "Threshold(0) = {}", d[1]);
2317 assert!((d[2] - (-1.0)).abs() < 1e-7, "Threshold(0.5) = {}", d[2]);
2318 assert!((d[3] - 1.0).abs() < 1e-7, "Threshold(1) = {}", d[3]);
2320 assert!((d[4] - 2.0).abs() < 1e-7, "Threshold(2) = {}", d[4]);
2321 }
2322
2323 #[test]
2324 fn test_threshold_module_trait() {
2325 let mut m = Threshold::new(0.5, -1.0);
2326 assert_zero_param_module::<Threshold, f64>(&mut m);
2327 }
2328
2329 #[test]
2334 fn test_softshrink_forward() {
2335 let m = Softshrink::default(); let x = t(&[-2.0, -0.5, -0.3, 0.0, 0.3, 0.5, 2.0]);
2337 let y = m.forward(&x).unwrap();
2338 let d = y.data().unwrap();
2339 assert!((d[6] - 1.5).abs() < 1e-7, "Softshrink(2) = {}", d[6]);
2341 assert!((d[0] - (-1.5)).abs() < 1e-7, "Softshrink(-2) = {}", d[0]);
2343 assert!((d[2] - 0.0).abs() < 1e-7, "Softshrink(-0.3) = {}", d[2]);
2345 assert!((d[3] - 0.0).abs() < 1e-7, "Softshrink(0) = {}", d[3]);
2346 assert!((d[4] - 0.0).abs() < 1e-7, "Softshrink(0.3) = {}", d[4]);
2347 assert!((d[1] - 0.0).abs() < 1e-7, "Softshrink(-0.5) = {}", d[1]);
2349 assert!((d[5] - 0.0).abs() < 1e-7, "Softshrink(0.5) = {}", d[5]);
2350 }
2351
2352 #[test]
2353 fn test_softshrink_custom_lambda() {
2354 let m = Softshrink::new(1.0);
2355 let x = t(&[-2.0, -0.5, 0.5, 2.0]);
2356 let y = m.forward(&x).unwrap();
2357 let d = y.data().unwrap();
2358 assert!((d[0] - (-1.0)).abs() < 1e-7);
2359 assert!((d[1] - 0.0).abs() < 1e-7);
2360 assert!((d[2] - 0.0).abs() < 1e-7);
2361 assert!((d[3] - 1.0).abs() < 1e-7);
2362 }
2363
2364 #[test]
2365 fn test_softshrink_module_trait() {
2366 let mut m = Softshrink::default();
2367 assert_zero_param_module::<Softshrink, f64>(&mut m);
2368 }
2369
2370 #[test]
2375 fn test_hardshrink_forward() {
2376 let m = Hardshrink::default(); let x = t(&[-2.0, -0.5, -0.3, 0.0, 0.3, 0.5, 2.0]);
2378 let y = m.forward(&x).unwrap();
2379 let d = y.data().unwrap();
2380 assert!((d[0] - (-2.0)).abs() < 1e-7, "Hardshrink(-2) = {}", d[0]);
2382 assert!((d[6] - 2.0).abs() < 1e-7, "Hardshrink(2) = {}", d[6]);
2383 assert!((d[2] - 0.0).abs() < 1e-7, "Hardshrink(-0.3) = {}", d[2]);
2385 assert!((d[3] - 0.0).abs() < 1e-7, "Hardshrink(0) = {}", d[3]);
2386 assert!((d[4] - 0.0).abs() < 1e-7, "Hardshrink(0.3) = {}", d[4]);
2387 assert!((d[1] - 0.0).abs() < 1e-7, "Hardshrink(-0.5) = {}", d[1]);
2389 assert!((d[5] - 0.0).abs() < 1e-7, "Hardshrink(0.5) = {}", d[5]);
2390 }
2391
2392 #[test]
2393 fn test_hardshrink_module_trait() {
2394 let mut m = Hardshrink::default();
2395 assert_zero_param_module::<Hardshrink, f64>(&mut m);
2396 }
2397
2398 #[test]
2403 fn test_tanhshrink_forward() {
2404 let m = Tanhshrink::new();
2405 let x = t(&[0.0]);
2407 let y = m.forward(&x).unwrap();
2408 assert!(
2409 y.data().unwrap()[0].abs() < 1e-7,
2410 "Tanhshrink(0) should be 0"
2411 );
2412
2413 let x = t(&[10.0, -10.0]);
2415 let y = m.forward(&x).unwrap();
2416 let d = y.data().unwrap();
2417 assert!(
2418 (d[0] - 9.0).abs() < 0.01,
2419 "Tanhshrink(10) should be ~9, got {}",
2420 d[0]
2421 );
2422 assert!(
2423 (d[1] - (-9.0)).abs() < 0.01,
2424 "Tanhshrink(-10) should be ~-9, got {}",
2425 d[1]
2426 );
2427
2428 let x = t(&[1.0]);
2430 let y = m.forward(&x).unwrap();
2431 let expected = 1.0 - 1.0_f64.tanh();
2432 assert!(
2433 (y.data().unwrap()[0] - expected).abs() < 1e-7,
2434 "Tanhshrink(1) expected {}, got {}",
2435 expected,
2436 y.data().unwrap()[0]
2437 );
2438 }
2439
2440 #[test]
2441 fn test_tanhshrink_module_trait() {
2442 let mut m = Tanhshrink::new();
2443 assert_zero_param_module::<Tanhshrink, f64>(&mut m);
2444 }
2445
2446 #[test]
2451 fn test_softsign_forward() {
2452 let m = Softsign::new();
2453 let x = t(&[0.0]);
2455 let y = m.forward(&x).unwrap();
2456 assert!(y.data().unwrap()[0].abs() < 1e-7, "Softsign(0) should be 0");
2457
2458 let x = t(&[1.0]);
2460 let y = m.forward(&x).unwrap();
2461 assert!(
2462 (y.data().unwrap()[0] - 0.5).abs() < 1e-7,
2463 "Softsign(1) should be 0.5"
2464 );
2465
2466 let x = t(&[-1.0]);
2468 let y = m.forward(&x).unwrap();
2469 assert!(
2470 (y.data().unwrap()[0] - (-0.5)).abs() < 1e-7,
2471 "Softsign(-1) should be -0.5"
2472 );
2473
2474 let x = t(&[100.0, -100.0]);
2476 let y = m.forward(&x).unwrap();
2477 let d = y.data().unwrap();
2478 assert!(
2479 d[0] > 0.99 && d[0] < 1.0,
2480 "Softsign(100) should be ~1, got {}",
2481 d[0]
2482 );
2483 assert!(
2484 d[1] < -0.99 && d[1] > -1.0,
2485 "Softsign(-100) should be ~-1, got {}",
2486 d[1]
2487 );
2488 }
2489
2490 #[test]
2491 fn test_softsign_module_trait() {
2492 let mut m = Softsign::new();
2493 assert_zero_param_module::<Softsign, f64>(&mut m);
2494 }
2495
2496 #[test]
2501 #[allow(clippy::field_reassign_with_default)]
2502 fn test_rrelu_eval_forward() {
2503 let mut m = RReLU::default(); m.training = false;
2506 let mean_slope = (1.0 / 8.0 + 1.0 / 3.0) / 2.0;
2507
2508 let x = t(&[-2.0, -1.0, 0.0, 1.0, 2.0]);
2509 let y = m.forward(&x).unwrap();
2510 let d = y.data().unwrap();
2511
2512 assert!(
2513 (d[0] - (-2.0 * mean_slope)).abs() < 1e-7,
2514 "RReLU(-2,eval) = {}",
2515 d[0]
2516 );
2517 assert!(
2518 (d[1] - (-mean_slope)).abs() < 1e-7,
2519 "RReLU(-1,eval) = {}",
2520 d[1]
2521 );
2522 assert!((d[2] - 0.0).abs() < 1e-7, "RReLU(0,eval) = {}", d[2]);
2523 assert!((d[3] - 1.0).abs() < 1e-7, "RReLU(1,eval) = {}", d[3]);
2524 assert!((d[4] - 2.0).abs() < 1e-7, "RReLU(2,eval) = {}", d[4]);
2525 }
2526
2527 #[test]
2528 fn test_rrelu_training_positive_passthrough() {
2529 let m = RReLU::default();
2531 let x = t(&[0.0, 1.0, 5.0, 100.0]);
2532 let y = m.forward(&x).unwrap();
2533 let d = y.data().unwrap();
2534 assert!((d[0] - 0.0).abs() < 1e-7);
2535 assert!((d[1] - 1.0).abs() < 1e-7);
2536 assert!((d[2] - 5.0).abs() < 1e-7);
2537 assert!((d[3] - 100.0).abs() < 1e-7);
2538 }
2539
2540 #[test]
2541 fn test_rrelu_training_negative_bounded() {
2542 let m = RReLU::new(0.1, 0.5);
2544 let x = t(&[-1.0; 100]); let y = m.forward(&x).unwrap();
2546 let d = y.data().unwrap();
2547
2548 for (i, &val) in d.iter().enumerate() {
2549 assert!(
2551 (-0.5 - 1e-7..=-0.1 + 1e-7).contains(&val),
2552 "RReLU(-1, train)[{}] = {} not in [-0.5, -0.1]",
2553 i,
2554 val
2555 );
2556 }
2557
2558 let first = d[0];
2560 let has_variance = d.iter().any(|&v| (v - first).abs() > 1e-10);
2561 assert!(has_variance, "RReLU training should produce varying slopes");
2562 }
2563
2564 #[test]
2565 fn test_rrelu_module_trait() {
2566 let mut m = RReLU::default();
2567 assert_zero_param_module::<RReLU, f64>(&mut m);
2568 }
2569
2570 #[test]
2575 fn test_defaults() {
2576 let _relu = ReLU::default();
2577 let _gelu = GELU::default();
2578 let _silu = SiLU::default();
2579 let _sigmoid = Sigmoid::default();
2580 let _tanh = Tanh::default();
2581 let _softmax = Softmax::default();
2582 let _log_softmax = LogSoftmax::default();
2583
2584 let lrelu = LeakyReLU::default();
2585 assert!((lrelu.negative_slope - 0.01).abs() < f64::EPSILON);
2586
2587 let elu = ELU::default();
2588 assert!((elu.alpha - 1.0).abs() < f64::EPSILON);
2589
2590 let _mish = Mish::default();
2591
2592 let celu = CELU::default();
2593 assert!((celu.alpha - 1.0).abs() < f64::EPSILON);
2594
2595 let _selu = SELU::default();
2596 let _hard_sigmoid = HardSigmoid::default();
2597 let _hard_swish = HardSwish::default();
2598
2599 let softplus = Softplus::default();
2600 assert!((softplus.beta - 1.0).abs() < f64::EPSILON);
2601
2602 let _glu = GLU::default();
2603
2604 let _relu6 = ReLU6::default();
2606
2607 let hardtanh = Hardtanh::default();
2608 assert!((hardtanh.min_val - (-1.0)).abs() < f64::EPSILON);
2609 assert!((hardtanh.max_val - 1.0).abs() < f64::EPSILON);
2610
2611 let _log_sigmoid = LogSigmoid::default();
2612 let _softmin = Softmin::default();
2613
2614 let softshrink = Softshrink::default();
2615 assert!((softshrink.lambda - 0.5).abs() < f64::EPSILON);
2616
2617 let hardshrink = Hardshrink::default();
2618 assert!((hardshrink.lambda - 0.5).abs() < f64::EPSILON);
2619
2620 let _tanhshrink = Tanhshrink::default();
2621 let _softsign = Softsign::default();
2622
2623 let rrelu = RReLU::default();
2624 assert!((rrelu.lower - 1.0 / 8.0).abs() < f64::EPSILON);
2625 assert!((rrelu.upper - 1.0 / 3.0).abs() < f64::EPSILON);
2626 }
2627
2628 #[test]
2633 fn test_send_sync() {
2634 fn assert_send_sync<T: Send + Sync>() {}
2635 assert_send_sync::<ReLU>();
2636 assert_send_sync::<GELU>();
2637 assert_send_sync::<SiLU>();
2638 assert_send_sync::<Sigmoid>();
2639 assert_send_sync::<Tanh>();
2640 assert_send_sync::<Softmax>();
2641 assert_send_sync::<LogSoftmax>();
2642 assert_send_sync::<LeakyReLU>();
2643 assert_send_sync::<ELU>();
2644 assert_send_sync::<Mish>();
2645 assert_send_sync::<PReLU<f64>>();
2646 assert_send_sync::<CELU>();
2647 assert_send_sync::<SELU>();
2648 assert_send_sync::<HardSigmoid>();
2649 assert_send_sync::<HardSwish>();
2650 assert_send_sync::<Softplus>();
2651 assert_send_sync::<GLU>();
2652 assert_send_sync::<ReLU6>();
2654 assert_send_sync::<Hardtanh>();
2655 assert_send_sync::<LogSigmoid>();
2656 assert_send_sync::<Softmin>();
2657 assert_send_sync::<Threshold>();
2658 assert_send_sync::<Softshrink>();
2659 assert_send_sync::<Hardshrink>();
2660 assert_send_sync::<Tanhshrink>();
2661 assert_send_sync::<Softsign>();
2662 assert_send_sync::<RReLU>();
2663 }
2664
2665 fn t_grad(data: &[f64]) -> Tensor<f64> {
2671 Tensor::from_storage(TensorStorage::cpu(data.to_vec()), vec![data.len()], true).unwrap()
2672 }
2673
2674 fn t_scalar_grad(val: f64) -> Tensor<f64> {
2676 Tensor::from_storage(TensorStorage::cpu(vec![val]), vec![], true).unwrap()
2677 }
2678
2679 fn numerical_grad(f: impl Fn(f64) -> f64, x: f64) -> f64 {
2681 let h = 1e-5;
2682 (f(x + h) - f(x - h)) / (2.0 * h)
2683 }
2684
2685 #[test]
2688 fn test_softplus_backward_produces_grad() {
2689 let x = t_scalar_grad(1.0);
2690 let m = Softplus::new(1.0);
2691 let y = m.forward(&x).unwrap();
2692 ferrotorch_core::backward(&y).unwrap();
2693
2694 let grad = x.grad().unwrap();
2695 assert!(
2696 grad.is_some(),
2697 "Softplus backward should produce a gradient"
2698 );
2699 }
2700
2701 #[test]
2702 fn test_softplus_backward_at_zero() {
2703 let x = t_scalar_grad(0.0);
2704 let m = Softplus::new(1.0);
2705 let y = m.forward(&x).unwrap();
2706 ferrotorch_core::backward(&y).unwrap();
2707
2708 let grad = x.grad().unwrap().unwrap();
2709 assert!(
2711 (grad.item().unwrap() - 0.5).abs() < 1e-6,
2712 "Softplus grad at x=0: expected 0.5, got {}",
2713 grad.item().unwrap()
2714 );
2715 }
2716
2717 #[test]
2718 fn test_softplus_backward_matches_numerical() {
2719 for &val in &[-2.0, -0.5, 0.0, 1.0, 3.0] {
2720 let x = t_scalar_grad(val);
2721 let m = Softplus::new(1.0);
2722 let y = m.forward(&x).unwrap();
2723 ferrotorch_core::backward(&y).unwrap();
2724
2725 let grad = x.grad().unwrap().unwrap();
2726 let expected = numerical_grad(|v| (1.0 + v.exp()).ln(), val);
2727 assert!(
2728 (grad.item().unwrap() - expected).abs() < 1e-4,
2729 "Softplus grad at x={}: expected {}, got {}",
2730 val,
2731 expected,
2732 grad.item().unwrap()
2733 );
2734 }
2735 }
2736
2737 #[test]
2738 fn test_softplus_backward_custom_beta() {
2739 let val = 1.0;
2740 let beta = 2.0;
2741 let x = t_scalar_grad(val);
2742 let m = Softplus::new(beta);
2743 let y = m.forward(&x).unwrap();
2744 ferrotorch_core::backward(&y).unwrap();
2745
2746 let grad = x.grad().unwrap().unwrap();
2747 let expected = numerical_grad(|v| (1.0 + (beta * v).exp()).ln() / beta, val);
2748 assert!(
2749 (grad.item().unwrap() - expected).abs() < 1e-4,
2750 "Softplus grad at x={}, beta={}: expected {}, got {}",
2751 val,
2752 beta,
2753 expected,
2754 grad.item().unwrap()
2755 );
2756 }
2757
2758 #[test]
2759 fn test_softplus_backward_vector() {
2760 let x = t_grad(&[-2.0, -0.5, 0.0, 1.0, 3.0]);
2761 let m = Softplus::new(1.0);
2762 let y = m.forward(&x).unwrap();
2763 let sum = ferrotorch_core::grad_fns::reduction::sum(&y).unwrap();
2765 ferrotorch_core::backward(&sum).unwrap();
2766
2767 let grad = x.grad().unwrap().unwrap();
2768 let grad_data = grad.data().unwrap();
2769
2770 for (i, &val) in [-2.0_f64, -0.5, 0.0, 1.0, 3.0].iter().enumerate() {
2771 let expected = numerical_grad(|v| (1.0 + v.exp()).ln(), val);
2772 assert!(
2773 (grad_data[i] - expected).abs() < 1e-4,
2774 "Softplus grad[{}] at x={}: expected {}, got {}",
2775 i,
2776 val,
2777 expected,
2778 grad_data[i]
2779 );
2780 }
2781 }
2782
2783 #[test]
2786 fn test_elu_backward_produces_grad() {
2787 let x = t_scalar_grad(-1.0);
2788 let m = ELU::new(1.0);
2789 let y = m.forward(&x).unwrap();
2790 ferrotorch_core::backward(&y).unwrap();
2791
2792 let grad = x.grad().unwrap();
2793 assert!(grad.is_some(), "ELU backward should produce a gradient");
2794 }
2795
2796 #[test]
2797 fn test_elu_backward_positive() {
2798 let x = t_scalar_grad(2.0);
2799 let m = ELU::new(1.0);
2800 let y = m.forward(&x).unwrap();
2801 ferrotorch_core::backward(&y).unwrap();
2802
2803 let grad = x.grad().unwrap().unwrap();
2804 assert!(
2806 (grad.item().unwrap() - 1.0).abs() < 1e-6,
2807 "ELU grad at x=2: expected 1.0, got {}",
2808 grad.item().unwrap()
2809 );
2810 }
2811
2812 #[test]
2813 fn test_elu_backward_matches_numerical() {
2814 let alpha = 1.0;
2815 for &val in &[-2.0, -1.0, -0.5, 0.5, 2.0] {
2816 let x = t_scalar_grad(val);
2817 let m = ELU::new(alpha);
2818 let y = m.forward(&x).unwrap();
2819 ferrotorch_core::backward(&y).unwrap();
2820
2821 let grad = x.grad().unwrap().unwrap();
2822 let expected =
2823 numerical_grad(|v| if v > 0.0 { v } else { alpha * (v.exp() - 1.0) }, val);
2824 assert!(
2825 (grad.item().unwrap() - expected).abs() < 1e-4,
2826 "ELU grad at x={}: expected {}, got {}",
2827 val,
2828 expected,
2829 grad.item().unwrap()
2830 );
2831 }
2832 }
2833
2834 #[test]
2835 fn test_elu_backward_custom_alpha() {
2836 let alpha = 2.0;
2837 let val = -0.5;
2838 let x = t_scalar_grad(val);
2839 let m = ELU::new(alpha);
2840 let y = m.forward(&x).unwrap();
2841 ferrotorch_core::backward(&y).unwrap();
2842
2843 let grad = x.grad().unwrap().unwrap();
2844 let expected = alpha * val.exp();
2846 assert!(
2847 (grad.item().unwrap() - expected).abs() < 1e-5,
2848 "ELU grad at x={}, alpha={}: expected {}, got {}",
2849 val,
2850 alpha,
2851 expected,
2852 grad.item().unwrap()
2853 );
2854 }
2855
2856 #[test]
2859 fn test_mish_backward_produces_grad() {
2860 let x = t_scalar_grad(1.0);
2861 let m = Mish::new();
2862 let y = m.forward(&x).unwrap();
2863 ferrotorch_core::backward(&y).unwrap();
2864
2865 let grad = x.grad().unwrap();
2866 assert!(grad.is_some(), "Mish backward should produce a gradient");
2867 }
2868
2869 #[test]
2870 fn test_mish_backward_matches_numerical() {
2871 let mish_fn = |v: f64| {
2872 let sp = (1.0 + v.exp()).ln();
2873 v * sp.tanh()
2874 };
2875
2876 for &val in &[-2.0, -1.0, 0.0, 0.5, 1.5, 3.0] {
2877 let x = t_scalar_grad(val);
2878 let m = Mish::new();
2879 let y = m.forward(&x).unwrap();
2880 ferrotorch_core::backward(&y).unwrap();
2881
2882 let grad = x.grad().unwrap().unwrap();
2883 let expected = numerical_grad(mish_fn, val);
2884 assert!(
2885 (grad.item().unwrap() - expected).abs() < 1e-4,
2886 "Mish grad at x={}: expected {}, got {}",
2887 val,
2888 expected,
2889 grad.item().unwrap()
2890 );
2891 }
2892 }
2893
2894 #[test]
2895 fn test_mish_backward_vector() {
2896 let x = t_grad(&[-1.0, 0.0, 1.0, 2.0]);
2897 let m = Mish::new();
2898 let y = m.forward(&x).unwrap();
2899 let sum = ferrotorch_core::grad_fns::reduction::sum(&y).unwrap();
2900 ferrotorch_core::backward(&sum).unwrap();
2901
2902 let grad = x.grad().unwrap().unwrap();
2903 let grad_data = grad.data().unwrap();
2904
2905 let mish_fn = |v: f64| {
2906 let sp = (1.0 + v.exp()).ln();
2907 v * sp.tanh()
2908 };
2909
2910 for (i, &val) in [-1.0_f64, 0.0, 1.0, 2.0].iter().enumerate() {
2911 let expected = numerical_grad(mish_fn, val);
2912 assert!(
2913 (grad_data[i] - expected).abs() < 1e-4,
2914 "Mish grad[{}] at x={}: expected {}, got {}",
2915 i,
2916 val,
2917 expected,
2918 grad_data[i]
2919 );
2920 }
2921 }
2922
2923 #[test]
2930 fn test_relu6_backward_matches_numerical() {
2931 let relu6_fn = |v: f64| v.clamp(0.0, 6.0);
2932
2933 for &val in &[-2.0, 0.5, 3.0, 5.5, 8.0] {
2934 let x = t_scalar_grad(val);
2935 let m = ReLU6::new();
2936 let y = m.forward(&x).unwrap();
2937 ferrotorch_core::backward(&y).unwrap();
2938
2939 let grad = x.grad().unwrap().unwrap();
2940 let expected = numerical_grad(relu6_fn, val);
2941 assert!(
2942 (grad.item().unwrap() - expected).abs() < 1e-4,
2943 "ReLU6 grad at x={}: expected {}, got {}",
2944 val,
2945 expected,
2946 grad.item().unwrap()
2947 );
2948 }
2949 }
2950
2951 #[test]
2954 fn test_hardtanh_backward_matches_numerical() {
2955 let hardtanh_fn = |v: f64| v.clamp(-1.0, 1.0);
2956
2957 for &val in &[-2.0, -0.5, 0.0, 0.5, 2.0] {
2958 let x = t_scalar_grad(val);
2959 let m = Hardtanh::default();
2960 let y = m.forward(&x).unwrap();
2961 ferrotorch_core::backward(&y).unwrap();
2962
2963 let grad = x.grad().unwrap().unwrap();
2964 let expected = numerical_grad(hardtanh_fn, val);
2965 assert!(
2966 (grad.item().unwrap() - expected).abs() < 1e-4,
2967 "Hardtanh grad at x={}: expected {}, got {}",
2968 val,
2969 expected,
2970 grad.item().unwrap()
2971 );
2972 }
2973 }
2974
2975 #[test]
2978 fn test_log_sigmoid_backward_matches_numerical() {
2979 let logsigmoid_fn = |v: f64| {
2980 -(1.0 + (-v).exp()).ln()
2982 };
2983
2984 for &val in &[-3.0, -1.0, 0.0, 1.0, 3.0] {
2985 let x = t_scalar_grad(val);
2986 let m = LogSigmoid::new();
2987 let y = m.forward(&x).unwrap();
2988 ferrotorch_core::backward(&y).unwrap();
2989
2990 let grad = x.grad().unwrap().unwrap();
2991 let expected = numerical_grad(logsigmoid_fn, val);
2992 assert!(
2993 (grad.item().unwrap() - expected).abs() < 1e-4,
2994 "LogSigmoid grad at x={}: expected {}, got {}",
2995 val,
2996 expected,
2997 grad.item().unwrap()
2998 );
2999 }
3000 }
3001
3002 #[test]
3005 fn test_tanhshrink_backward_matches_numerical() {
3006 let tanhshrink_fn = |v: f64| v - v.tanh();
3007
3008 for &val in &[-2.0, -0.5, 0.0, 0.5, 2.0] {
3009 let x = t_scalar_grad(val);
3010 let m = Tanhshrink::new();
3011 let y = m.forward(&x).unwrap();
3012 ferrotorch_core::backward(&y).unwrap();
3013
3014 let grad = x.grad().unwrap().unwrap();
3015 let expected = numerical_grad(tanhshrink_fn, val);
3016 assert!(
3017 (grad.item().unwrap() - expected).abs() < 1e-4,
3018 "Tanhshrink grad at x={}: expected {}, got {}",
3019 val,
3020 expected,
3021 grad.item().unwrap()
3022 );
3023 }
3024 }
3025
3026 #[test]
3031 fn test_state_dict_empty() {
3032 let m = ReLU::new();
3033 let sd = Module::<f64>::state_dict(&m);
3034 assert!(sd.is_empty());
3035 }
3036
3037 #[test]
3045 #[cfg(feature = "cuda")]
3046 #[ignore = "needs CUDA hardware; tracking #1451"]
3047 fn softmax2d_forward_gpu_matches_cpu() {
3048 use ferrotorch_core::Device;
3049 use ferrotorch_core::storage::TensorStorage;
3050 use ferrotorch_gpu::init_cuda_backend;
3051 init_cuda_backend().expect("CUDA init failed");
3052
3053 let (n, c, h, w) = (2usize, 5usize, 3usize, 4usize);
3055 let total = n * c * h * w;
3056 let data: Vec<f32> = (0..total).map(|k| ((k % 11) as f32) * 0.3 - 1.4).collect();
3057
3058 let sm = Softmax2d::new();
3059
3060 let x_cpu = Tensor::from_storage(TensorStorage::cpu(data.clone()), vec![n, c, h, w], false)
3061 .unwrap();
3062 let y_cpu = sm.forward(&x_cpu).unwrap();
3063 let cpu_vals = y_cpu.data().unwrap().to_vec();
3064
3065 let x_gpu = x_cpu.to(Device::Cuda(0)).unwrap();
3066 let y_gpu = sm.forward(&x_gpu).unwrap();
3067 assert!(y_gpu.is_cuda(), "Softmax2d GPU output must stay on CUDA");
3068 let gpu_vals = y_gpu.data_vec().unwrap();
3069
3070 assert_eq!(gpu_vals.len(), cpu_vals.len());
3071 let mut max_abs = 0.0f32;
3072 for (g, c) in gpu_vals.iter().zip(cpu_vals.iter()) {
3073 max_abs = max_abs.max((g - c).abs());
3074 }
3075 assert!(max_abs < 1e-4, "Softmax2d GPU vs CPU max|Δ| = {max_abs}");
3076 }
3077}