use std::any::Any;

use axonml_autograd::no_grad::is_grad_enabled;
use axonml_autograd::{GradFn, GradientFunction, Variable};
use axonml_tensor::Tensor;

use crate::module::Module;

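/// How an element-wise loss is reduced to a final value.
///
/// `None` returns the per-element loss, `Mean` (the default) averages over
/// all elements, and `Sum` adds them up.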
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub enum Reduction {
    None,
    #[default]
    Mean,
    Sum,
}

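/// Mean squared error loss: `(input - target)^2`, reduced per [`Reduction`].
///
/// A minimal usage sketch, mirroring the tests below (compile-checked there,
/// not here):
///
/// ```ignore
/// use axonml_autograd::backward;
///
/// let loss = MSELoss::new().compute(&prediction, &target);
/// let ones = Tensor::from_vec(vec![1.0], &loss.shape()).unwrap();
/// backward(&loss, &ones);
/// ```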
#[derive(Debug, Clone, Copy)]
pub struct MSELoss {
    reduction: Reduction,
}

impl MSELoss {
    pub fn new() -> Self {
        Self {
            reduction: Reduction::Mean,
        }
    }

    pub fn with_reduction(reduction: Reduction) -> Self {
        Self { reduction }
    }

    pub fn compute(&self, input: &Variable, target: &Variable) -> Variable {
        let diff = input.sub_var(target);
        let squared = diff.pow(2.0);

        match self.reduction {
            Reduction::None => squared,
            Reduction::Mean => squared.mean(),
            Reduction::Sum => squared.sum(),
        }
    }
}

impl Default for MSELoss {
    fn default() -> Self {
        Self::new()
    }
}

impl Module for MSELoss {
    fn forward(&self, input: &Variable) -> Variable {
        // A loss needs a target, so `forward` alone is the identity;
        // use `compute` to evaluate the loss.
        input.clone()
    }

    fn name(&self) -> &'static str {
        "MSELoss"
    }
}

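/// Mean absolute error loss: `|input - target|`, reduced per [`Reduction`].
///
/// The absolute value is assembled from two `clamp_min` (ReLU) calls so the
/// forward pass stays in tensor ops; the backward pass is wired up manually
/// through `L1LossBackward`.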
#[derive(Debug, Clone, Copy)]
pub struct L1Loss {
    reduction: Reduction,
}

impl L1Loss {
    pub fn new() -> Self {
        Self {
            reduction: Reduction::Mean,
        }
    }

    pub fn with_reduction(reduction: Reduction) -> Self {
        Self { reduction }
    }

    pub fn compute(&self, input: &Variable, target: &Variable) -> Variable {
        let input_data = input.data();
        let target_data = target.data();
        let diff_tensor = input_data.sub(&target_data).expect("tensor sub failed");
        // |x| = relu(x) + relu(-x)
        let relu_diff = axonml_tensor::ops::clamp_min(&diff_tensor, 0.0);
        let relu_neg_diff = axonml_tensor::ops::clamp_min(&diff_tensor.neg(), 0.0);
        let abs_tensor = relu_diff.add(&relu_neg_diff).expect("tensor add failed");

        let requires_grad = (input.requires_grad() || target.requires_grad()) && is_grad_enabled();
        let loss_var = if requires_grad {
            let grad_fn = GradFn::new(L1LossBackward {
                next_fns: vec![input.grad_fn().cloned(), target.grad_fn().cloned()],
                diff_tensor,
            });
            Variable::from_operation(abs_tensor, grad_fn, true)
        } else {
            Variable::new(abs_tensor, false)
        };

        match self.reduction {
            Reduction::None => loss_var,
            Reduction::Mean => loss_var.mean(),
            Reduction::Sum => loss_var.sum(),
        }
    }
}

impl Default for L1Loss {
    fn default() -> Self {
        Self::new()
    }
}

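/// Backward pass for [`L1Loss`].
///
/// Recovers `sign(diff)` as `diff / sqrt(diff^2 + eps)`; the epsilon keeps
/// the division finite where `diff == 0`.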
#[derive(Debug)]
struct L1LossBackward {
    next_fns: Vec<Option<GradFn>>,
    diff_tensor: Tensor<f32>,
}

impl GradientFunction for L1LossBackward {
    fn apply(&self, grad_output: &Tensor<f32>) -> Vec<Option<Tensor<f32>>> {
        let eps_tensor = Tensor::full(self.diff_tensor.shape(), 1e-12);
        let eps_on_device = if self.diff_tensor.device().is_gpu() {
            eps_tensor.to_device(self.diff_tensor.device()).unwrap()
        } else {
            eps_tensor
        };
        let diff_sq = self
            .diff_tensor
            .mul(&self.diff_tensor)
            .expect("tensor mul failed");
        let diff_sq_eps = diff_sq.add(&eps_on_device).expect("tensor add failed");
        // sqrt(x) = exp(0.5 * ln(x))
        let abs_diff = diff_sq_eps.ln().mul_scalar(0.5).exp();
        let sign_diff = self.diff_tensor.div(&abs_diff).unwrap();

        let gi = sign_diff.mul(grad_output).unwrap();
        let gt = gi.neg();
        vec![Some(gi), Some(gt)]
    }

    fn name(&self) -> &'static str {
        "L1LossBackward"
    }

    fn next_functions(&self) -> &[Option<GradFn>] {
        &self.next_fns
    }

    fn as_any(&self) -> &dyn Any {
        self
    }
}

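/// Backward pass for [`CrossEntropyLoss`].
///
/// Uses the cached softmax probabilities: the gradient w.r.t. each logit is
/// `softmax(x)_c - 1[c == target]`, scaled by the incoming gradient.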
#[derive(Debug)]
struct CrossEntropyBackward {
    next_fns: Vec<Option<GradFn>>,
    softmax_probs: Tensor<f32>,
    targets: Tensor<f32>,
    batch_size: usize,
    num_classes: usize,
}

impl GradientFunction for CrossEntropyBackward {
    fn apply(&self, grad_output: &Tensor<f32>) -> Vec<Option<Tensor<f32>>> {
        let softmax_vec = self.softmax_probs.to_vec();
        let target_vec = self.targets.to_vec();
        let grad_vec = grad_output.to_vec();
        let mut grad_input = vec![0.0f32; self.batch_size * self.num_classes];

        // A scalar upstream gradient (after mean/sum reduction) is broadcast
        // to every sample.
        let is_scalar_grad = grad_vec.len() == 1;

        for b in 0..self.batch_size {
            let grad_scale = if is_scalar_grad {
                grad_vec[0]
            } else if b < grad_vec.len() {
                grad_vec[b]
            } else {
                1.0 / self.batch_size as f32
            };
            let offset = b * self.num_classes;
            let tc = target_vec[b] as usize;
            for c in 0..self.num_classes {
                let mut g = softmax_vec[offset + c];
                if c == tc {
                    g -= 1.0;
                }
                grad_input[offset + c] = g * grad_scale;
            }
        }

        let mut grad_tensor = Tensor::from_vec(grad_input, &[self.batch_size, self.num_classes])
            .expect("tensor creation failed");
        if self.softmax_probs.device().is_gpu() {
            grad_tensor = grad_tensor.to_device(self.softmax_probs.device()).unwrap();
        }
        vec![Some(grad_tensor)]
    }

    fn name(&self) -> &'static str {
        "CrossEntropyBackward"
    }

    fn next_functions(&self) -> &[Option<GradFn>] {
        &self.next_fns
    }

    fn as_any(&self) -> &dyn Any {
        self
    }
}

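/// Cross-entropy loss over raw logits of shape `[batch, classes]`.
///
/// Combines a numerically stable log-softmax (max-subtraction trick) with
/// negative log-likelihood over integer class targets. With the `cuda`
/// feature, GPU inputs take a fused forward kernel.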
#[derive(Debug, Clone, Copy)]
pub struct CrossEntropyLoss {
    reduction: Reduction,
}

impl CrossEntropyLoss {
    pub fn new() -> Self {
        Self {
            reduction: Reduction::Mean,
        }
    }

    pub fn with_reduction(reduction: Reduction) -> Self {
        Self { reduction }
    }

    pub fn compute(&self, input: &Variable, target: &Variable) -> Variable {
        let input_data = input.data();
        let target_data = target.data();
        let shape = input_data.shape().to_vec();
        let batch_size = shape[0];
        let num_classes = shape[1];

        // Fused GPU path: the forward kernel returns both the per-sample
        // losses and the softmax probabilities needed for backward.
        #[cfg(feature = "cuda")]
        if input_data.device().is_gpu() {
            let targets_gpu = if target_data.device().is_gpu() {
                target_data.clone()
            } else {
                target_data.to_device(input_data.device()).unwrap()
            };

            let (loss_tensor, softmax_tensor) = input_data.cross_entropy_fwd_cuda(&targets_gpu);

            let loss_var = if input.requires_grad() && is_grad_enabled() {
                let grad_fn = GradFn::new(CrossEntropyBackward {
                    next_fns: vec![input.grad_fn().cloned()],
                    softmax_probs: softmax_tensor,
                    targets: targets_gpu,
                    batch_size,
                    num_classes,
                });
                Variable::from_operation(loss_tensor, grad_fn, true)
            } else {
                Variable::new(loss_tensor, false)
            };

            return match self.reduction {
                Reduction::None => loss_var,
                Reduction::Mean => loss_var.mean(),
                Reduction::Sum => loss_var.sum(),
            };
        }

        let input_vec = input_data.to_vec();
        let target_vec = target_data.to_vec();

        let mut losses = vec![0.0f32; batch_size];
        let mut softmax_probs_vec = vec![0.0f32; batch_size * num_classes];
        let mut target_classes = vec![0usize; batch_size];

        for b in 0..batch_size {
            let offset = b * num_classes;

            // Subtract the row max before exponentiating for numerical
            // stability.
            let max_val = (0..num_classes)
                .map(|c| input_vec[offset + c])
                .fold(f32::NEG_INFINITY, f32::max);

            let mut sum_exp = 0.0f32;
            for c in 0..num_classes {
                let exp_val = (input_vec[offset + c] - max_val).exp();
                softmax_probs_vec[offset + c] = exp_val;
                sum_exp += exp_val;
            }

            for c in 0..num_classes {
                softmax_probs_vec[offset + c] /= sum_exp;
            }

            let log_sum_exp = max_val + sum_exp.ln();

            // loss = log_sum_exp(x) - x[target]
            let tc = target_vec[b] as usize;
            target_classes[b] = tc;
            losses[b] = log_sum_exp - input_vec[offset + tc];
        }

        let loss_tensor = Tensor::from_vec(losses, &[batch_size]).expect("tensor creation failed");
        let softmax_tensor = Tensor::from_vec(softmax_probs_vec, &[batch_size, num_classes])
            .expect("tensor creation failed");
        let targets_f32: Vec<f32> = target_classes.iter().map(|&tc| tc as f32).collect();
        let targets_tensor =
            Tensor::from_vec(targets_f32, &[batch_size]).expect("tensor creation failed");

        let loss_var = if input.requires_grad() && is_grad_enabled() {
            let grad_fn = GradFn::new(CrossEntropyBackward {
                next_fns: vec![input.grad_fn().cloned()],
                softmax_probs: softmax_tensor,
                targets: targets_tensor,
                batch_size,
                num_classes,
            });
            Variable::from_operation(loss_tensor, grad_fn, true)
        } else {
            Variable::new(loss_tensor, false)
        };

        match self.reduction {
            Reduction::None => loss_var,
            Reduction::Mean => loss_var.mean(),
            Reduction::Sum => loss_var.sum(),
        }
    }
}

impl Default for CrossEntropyLoss {
    fn default() -> Self {
        Self::new()
    }
}

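/// Negative log-likelihood loss over log-probabilities of shape
/// `[batch, classes]`: `-input[b, target[b]]` per sample.
///
/// Expects log-probabilities (e.g. a log-softmax output), unlike
/// [`CrossEntropyLoss`], which takes raw logits.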
#[derive(Debug, Clone, Copy)]
pub struct NLLLoss {
    reduction: Reduction,
}

impl NLLLoss {
    pub fn new() -> Self {
        Self {
            reduction: Reduction::Mean,
        }
    }

    pub fn with_reduction(reduction: Reduction) -> Self {
        Self { reduction }
    }

    pub fn compute(&self, input: &Variable, target: &Variable) -> Variable {
        let input_data = input.data();
        let target_data = target.data();
        let shape = input_data.shape().to_vec();
        let batch_size = shape[0];
        let num_classes = shape[1];

        let target_vec = target_data.to_vec();
        let input_vec = input_data.to_vec();

        // Pick out -log_prob at the target class of each sample.
        let mut losses = vec![0.0f32; batch_size];
        for b in 0..batch_size {
            let tc = target_vec[b] as usize;
            losses[b] = -input_vec[b * num_classes + tc];
        }

        let mut loss_tensor =
            Tensor::from_vec(losses, &[batch_size]).expect("tensor creation failed");
        if input_data.device().is_gpu() {
            loss_tensor = loss_tensor.to_device(input_data.device()).unwrap();
        }

        let requires_grad = input.requires_grad() && is_grad_enabled();
        let loss_var = if requires_grad {
            let grad_fn = GradFn::new(NLLLossBackward {
                next_fns: vec![input.grad_fn().cloned()],
                target_tensor: target_data.clone(),
                batch_size,
                num_classes,
            });
            Variable::from_operation(loss_tensor, grad_fn, true)
        } else {
            Variable::new(loss_tensor, false)
        };

        match self.reduction {
            Reduction::None => loss_var,
            Reduction::Mean => loss_var.mean(),
            Reduction::Sum => loss_var.sum(),
        }
    }
}

impl Default for NLLLoss {
    fn default() -> Self {
        Self::new()
    }
}

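/// Backward pass for [`NLLLoss`]: the gradient is `-grad_output[b]` at each
/// sample's target class and zero elsewhere.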
#[derive(Debug)]
struct NLLLossBackward {
    next_fns: Vec<Option<GradFn>>,
    target_tensor: Tensor<f32>,
    batch_size: usize,
    num_classes: usize,
}

impl GradientFunction for NLLLossBackward {
    fn apply(&self, grad_output: &Tensor<f32>) -> Vec<Option<Tensor<f32>>> {
        let grad_out_vec = grad_output.to_vec();
        let target_vec = self.target_tensor.to_vec();
        let mut grad_input = vec![0.0f32; self.batch_size * self.num_classes];

        for b in 0..self.batch_size {
            // A scalar upstream gradient (after reduction) applies to every
            // sample.
            let g = if grad_out_vec.len() == 1 {
                grad_out_vec[0]
            } else {
                grad_out_vec[b]
            };
            let tc = target_vec[b] as usize;
            grad_input[b * self.num_classes + tc] = -g;
        }

        let mut gi = Tensor::from_vec(grad_input, &[self.batch_size, self.num_classes])
            .expect("tensor creation failed");
        if grad_output.device().is_gpu() {
            gi = gi.to_device(grad_output.device()).unwrap();
        }
        vec![Some(gi)]
    }

    fn name(&self) -> &'static str {
        "NLLLossBackward"
    }

    fn next_functions(&self) -> &[Option<GradFn>] {
        &self.next_fns
    }

    fn as_any(&self) -> &dyn Any {
        self
    }
}

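/// Binary cross-entropy over probabilities in `(0, 1)`:
/// `-(t * ln(p) + (1 - t) * ln(1 - p))`.
///
/// Inputs are clamped to `[eps, 1 - eps]` to keep the logs finite. For raw
/// logits, prefer [`BCEWithLogitsLoss`].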
#[derive(Debug, Clone, Copy)]
pub struct BCELoss {
    reduction: Reduction,
}

impl BCELoss {
    pub fn new() -> Self {
        Self {
            reduction: Reduction::Mean,
        }
    }

    pub fn with_reduction(reduction: Reduction) -> Self {
        Self { reduction }
    }

    pub fn compute(&self, input: &Variable, target: &Variable) -> Variable {
        let input_data = input.data();
        let target_data = target.data();

        // Clamp probabilities away from 0 and 1 so the logs stay finite.
        let eps = 1e-7f32;
        let p_clamped = axonml_tensor::ops::clamp(&input_data, eps, 1.0 - eps);

        let ln_p = p_clamped.ln();
        let one_minus_p = p_clamped.neg().add_scalar(1.0);
        let ln_one_minus_p = one_minus_p.ln();
        let one_minus_t = target_data.neg().add_scalar(1.0);

        let term1 = target_data.mul(&ln_p).expect("tensor mul failed");
        let term2 = one_minus_t.mul(&ln_one_minus_p).expect("tensor mul failed");
        let loss_tensor = term1.add(&term2).expect("tensor add failed").neg();

        let requires_grad = input.requires_grad() && is_grad_enabled();
        let loss_var = if requires_grad {
            let grad_fn = GradFn::new(BCELossBackward {
                next_fns: vec![input.grad_fn().cloned()],
                input_tensor: input_data,
                target_tensor: target_data,
            });
            Variable::from_operation(loss_tensor, grad_fn, true)
        } else {
            Variable::new(loss_tensor, false)
        };

        match self.reduction {
            Reduction::None => loss_var,
            Reduction::Mean => loss_var.mean(),
            Reduction::Sum => loss_var.sum(),
        }
    }
}

impl Default for BCELoss {
    fn default() -> Self {
        Self::new()
    }
}

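/// Backward pass for [`BCELoss`]: `d/dp = (p - t) / (p * (1 - p))`, computed
/// on the same clamped probabilities as the forward pass.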
#[derive(Debug)]
struct BCELossBackward {
    next_fns: Vec<Option<GradFn>>,
    input_tensor: Tensor<f32>,
    target_tensor: Tensor<f32>,
}

impl GradientFunction for BCELossBackward {
    fn apply(&self, grad_output: &Tensor<f32>) -> Vec<Option<Tensor<f32>>> {
        let eps = 1e-7f32;
        let p_clamped = axonml_tensor::ops::clamp(&self.input_tensor, eps, 1.0 - eps);
        let p_minus_y = p_clamped
            .sub(&self.target_tensor)
            .expect("tensor sub failed");
        let one_minus_p = p_clamped.neg().add_scalar(1.0);
        let denom = p_clamped.mul(&one_minus_p).expect("tensor mul failed");
        let ratio = p_minus_y.div(&denom).unwrap();
        let grad_tensor = grad_output.mul(&ratio).expect("tensor mul failed");
        vec![Some(grad_tensor)]
    }

    fn name(&self) -> &'static str {
        "BCELossBackward"
    }

    fn next_functions(&self) -> &[Option<GradFn>] {
        &self.next_fns
    }

    fn as_any(&self) -> &dyn Any {
        self
    }
}

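/// Backward pass for [`BCEWithLogitsLoss`]: `d/dx = sigmoid(x) - t`.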
#[derive(Debug)]
struct BCEWithLogitsBackward {
    next_fns: Vec<Option<GradFn>>,
    input_tensor: Tensor<f32>,
    target_tensor: Tensor<f32>,
}

impl GradientFunction for BCEWithLogitsBackward {
    fn apply(&self, grad_output: &Tensor<f32>) -> Vec<Option<Tensor<f32>>> {
        let sig = self.input_tensor.sigmoid();
        let sig_minus_t = sig.sub(&self.target_tensor).expect("tensor sub failed");
        let grad_tensor = grad_output.mul(&sig_minus_t).expect("tensor mul failed");
        vec![Some(grad_tensor)]
    }

    fn name(&self) -> &'static str {
        "BCEWithLogitsBackward"
    }

    fn next_functions(&self) -> &[Option<GradFn>] {
        &self.next_fns
    }

    fn as_any(&self) -> &dyn Any {
        self
    }
}

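/// Binary cross-entropy over raw logits, using the numerically stable form
/// `max(x, 0) - x * t + ln(1 + exp(-|x|))`.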
#[derive(Debug, Clone, Copy)]
pub struct BCEWithLogitsLoss {
    reduction: Reduction,
}

impl BCEWithLogitsLoss {
    pub fn new() -> Self {
        Self {
            reduction: Reduction::Mean,
        }
    }

    pub fn with_reduction(reduction: Reduction) -> Self {
        Self { reduction }
    }

    pub fn compute(&self, input: &Variable, target: &Variable) -> Variable {
        let input_data = input.data();
        let target_data = target.data();

        // Stable formulation: loss = max(x, 0) - x * t + ln(1 + exp(-|x|)).
        let relu_x = axonml_tensor::ops::clamp_min(&input_data, 0.0);
        let x_times_t = input_data.mul(&target_data).expect("tensor mul failed");
        let neg_x = input_data.neg();
        let relu_neg_x = axonml_tensor::ops::clamp_min(&neg_x, 0.0);
        let abs_x = relu_x.add(&relu_neg_x).expect("tensor add failed");
        let exp_neg_abs = abs_x.neg().exp();
        let log_term = exp_neg_abs.add_scalar(1.0).ln();
        let loss_tensor = relu_x
            .sub(&x_times_t)
            .expect("tensor sub failed")
            .add(&log_term)
            .expect("tensor add failed");

        let loss_var = if input.requires_grad() && is_grad_enabled() {
            let grad_fn = GradFn::new(BCEWithLogitsBackward {
                next_fns: vec![input.grad_fn().cloned()],
                input_tensor: input_data,
                target_tensor: target_data,
            });
            Variable::from_operation(loss_tensor, grad_fn, true)
        } else {
            Variable::new(loss_tensor, false)
        };

        match self.reduction {
            Reduction::None => loss_var,
            Reduction::Mean => loss_var.mean(),
            Reduction::Sum => loss_var.sum(),
        }
    }
}

impl Default for BCEWithLogitsLoss {
    fn default() -> Self {
        Self::new()
    }
}

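/// Backward pass for [`SmoothL1Loss`]: `diff / beta` in the quadratic region
/// (`|diff| < beta`), `sign(diff)` in the linear region, blended with a 0/1
/// mask.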
#[derive(Debug)]
struct SmoothL1Backward {
    next_fns: Vec<Option<GradFn>>,
    diff_tensor: Tensor<f32>,
    beta: f32,
    shape: Vec<usize>,
}

impl GradientFunction for SmoothL1Backward {
    fn apply(&self, grad_output: &Tensor<f32>) -> Vec<Option<Tensor<f32>>> {
        let eps = 1e-12f32;
        let diff_sq = self
            .diff_tensor
            .mul(&self.diff_tensor)
            .expect("tensor mul failed");
        let diff_sq_eps = diff_sq.add_scalar(eps);
        let abs_diff = diff_sq_eps.ln().mul_scalar(0.5).exp();

        let sign_diff = self.diff_tensor.div(&abs_diff).unwrap();

        // Quadratic region: d/dx = diff / beta; linear region: d/dx = sign(diff).
        let grad_l2 = self.diff_tensor.mul_scalar(1.0 / self.beta);
        let grad_l1 = sign_diff;

        // Blend the two with a 0/1 mask on |diff| < beta.
        let abs_vec = abs_diff.to_vec();
        let beta = self.beta;
        let mask_vec: Vec<f32> = abs_vec
            .iter()
            .map(|&a| if a < beta { 1.0 } else { 0.0 })
            .collect();
        let mut mask = Tensor::from_vec(mask_vec, &self.shape).expect("tensor creation failed");
        if self.diff_tensor.device().is_gpu() {
            mask = mask.to_device(self.diff_tensor.device()).unwrap();
        }
        let inv_mask = mask.neg().add_scalar(1.0);

        let blended = mask
            .mul(&grad_l2)
            .unwrap()
            .add(&inv_mask.mul(&grad_l1).expect("tensor mul failed"))
            .unwrap();

        let gi = blended.mul(grad_output).unwrap();
        let gt = gi.neg();
        vec![Some(gi), Some(gt)]
    }

    fn name(&self) -> &'static str {
        "SmoothL1Backward"
    }

    fn next_functions(&self) -> &[Option<GradFn>] {
        &self.next_fns
    }

    fn as_any(&self) -> &dyn Any {
        self
    }
}

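/// Smooth L1 (Huber-style) loss: `0.5 * diff^2 / beta` where `|diff| < beta`,
/// `|diff| - 0.5 * beta` otherwise. Less sensitive to outliers than MSE.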
#[derive(Debug, Clone, Copy)]
pub struct SmoothL1Loss {
    reduction: Reduction,
    beta: f32,
}

impl SmoothL1Loss {
    pub fn new() -> Self {
        Self {
            reduction: Reduction::Mean,
            beta: 1.0,
        }
    }

    pub fn with_beta(beta: f32) -> Self {
        Self {
            reduction: Reduction::Mean,
            beta,
        }
    }

    pub fn compute(&self, input: &Variable, target: &Variable) -> Variable {
        let input_data = input.data();
        let target_data = target.data();
        let diff_tensor = input_data.sub(&target_data).expect("tensor sub failed");
        let shape = diff_tensor.shape().to_vec();

        // |diff| = relu(diff) + relu(-diff)
        let relu_diff = axonml_tensor::ops::clamp_min(&diff_tensor, 0.0);
        let relu_neg_diff = axonml_tensor::ops::clamp_min(&diff_tensor.neg(), 0.0);
        let abs_diff = relu_diff.add(&relu_neg_diff).expect("tensor add failed");

        // Quadratic branch: 0.5 * diff^2 / beta; linear branch: |diff| - 0.5 * beta.
        let diff_sq = diff_tensor.mul(&diff_tensor).expect("tensor mul failed");
        let l2_loss = diff_sq.mul_scalar(0.5 / self.beta);

        let l1_loss = abs_diff.add_scalar(-0.5 * self.beta);

        // Select per element with a 0/1 mask on |diff| < beta.
        let abs_vec = abs_diff.to_vec();
        let beta = self.beta;
        let mask_vec: Vec<f32> = abs_vec
            .iter()
            .map(|&a| if a < beta { 1.0 } else { 0.0 })
            .collect();
        let mut mask = Tensor::from_vec(mask_vec, &shape).expect("tensor creation failed");
        if diff_tensor.device().is_gpu() {
            mask = mask.to_device(diff_tensor.device()).unwrap();
        }
        let inv_mask = mask.neg().add_scalar(1.0);

        let loss_tensor = mask
            .mul(&l2_loss)
            .unwrap()
            .add(&inv_mask.mul(&l1_loss).expect("tensor mul failed"))
            .unwrap();

        let loss_var = if (input.requires_grad() || target.requires_grad()) && is_grad_enabled() {
            let grad_fn = GradFn::new(SmoothL1Backward {
                next_fns: vec![input.grad_fn().cloned(), target.grad_fn().cloned()],
                diff_tensor,
                beta: self.beta,
                shape,
            });
            Variable::from_operation(loss_tensor, grad_fn, true)
        } else {
            Variable::new(loss_tensor, false)
        };

        match self.reduction {
            Reduction::None => loss_var,
            Reduction::Mean => loss_var.mean(),
            Reduction::Sum => loss_var.sum(),
        }
    }
}

impl Default for SmoothL1Loss {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_mse_loss() {
        let loss_fn = MSELoss::new();
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("tensor creation failed"),
            false,
        );
        let target = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("tensor creation failed"),
            false,
        );
        let loss = loss_fn.compute(&input, &target);
        assert!((loss.data().to_vec()[0] - 0.0).abs() < 1e-6);
    }

    #[test]
    fn test_mse_loss_nonzero() {
        let loss_fn = MSELoss::new();
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("tensor creation failed"),
            false,
        );
        let target = Variable::new(
            Tensor::from_vec(vec![2.0, 3.0, 4.0], &[3]).expect("tensor creation failed"),
            false,
        );
        let loss = loss_fn.compute(&input, &target);
        // Every element is off by 1, so the mean squared error is 1.0.
        assert!((loss.data().to_vec()[0] - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_cross_entropy_loss() {
        let loss_fn = CrossEntropyLoss::new();
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 1.0, 2.0, 3.0], &[2, 3])
                .expect("tensor creation failed"),
            false,
        );
        let target = Variable::new(
            Tensor::from_vec(vec![2.0, 0.0], &[2]).expect("tensor creation failed"),
            false,
        );
        let loss = loss_fn.compute(&input, &target);
        assert!(loss.data().to_vec()[0] > 0.0);
    }

    #[test]
    fn test_bce_loss() {
        let loss_fn = BCELoss::new();
        let input = Variable::new(
            Tensor::from_vec(vec![0.5, 0.5], &[2]).expect("tensor creation failed"),
            false,
        );
        let target = Variable::new(
            Tensor::from_vec(vec![1.0, 0.0], &[2]).expect("tensor creation failed"),
            false,
        );
        let loss = loss_fn.compute(&input, &target);
        // -ln(0.5) ≈ 0.693 for both elements.
        assert!((loss.data().to_vec()[0] - 0.693).abs() < 0.01);
    }

    #[test]
    fn test_cross_entropy_gradient_flow() {
        use axonml_autograd::backward;

        let input = Variable::new(
            Tensor::from_vec(vec![2.0, 1.0, 0.1, 0.5, 2.5, 0.3], &[2, 3])
                .expect("tensor creation failed"),
            true,
        );
        let target = Variable::new(
            Tensor::from_vec(vec![0.0, 1.0], &[2]).expect("tensor creation failed"),
            false,
        );

        let loss_fn = CrossEntropyLoss::new();
        let loss = loss_fn.compute(&input, &target);

        let loss_val = loss.data().to_vec()[0];
        assert!(loss_val > 0.0, "Loss should be positive, got {}", loss_val);

        let ones = Tensor::from_vec(vec![1.0], &loss.shape()).unwrap();
        backward(&loss, &ones);

        let grad = input
            .grad()
            .expect("Input should have gradient after backward");
        let grad_vec = grad.to_vec();

        let grad_norm: f32 = grad_vec.iter().map(|g| g * g).sum();
        assert!(
            grad_norm > 1e-10,
            "Gradient should be non-zero, got norm {}",
            grad_norm
        );

        assert_eq!(grad.shape(), &[2, 3]);

        // Sample 0 targets class 0, sample 1 targets class 1.
        assert!(
            grad_vec[0] < 0.0,
            "Gradient for correct class should be negative"
        );
        assert!(
            grad_vec[4] < 0.0,
            "Gradient for correct class should be negative"
        );

        assert!(
            grad_vec[1] > 0.0,
            "Gradient for wrong class should be positive"
        );
        assert!(
            grad_vec[2] > 0.0,
            "Gradient for wrong class should be positive"
        );
    }

    #[test]
    fn test_cross_entropy_perfect_prediction() {
        let loss_fn = CrossEntropyLoss::new();
        let input = Variable::new(
            Tensor::from_vec(vec![10.0, -10.0, -10.0], &[1, 3]).expect("tensor creation failed"),
            false,
        );
        let target = Variable::new(
            Tensor::from_vec(vec![0.0], &[1]).expect("tensor creation failed"),
            false,
        );
        let loss = loss_fn.compute(&input, &target);
        assert!(
            loss.data().to_vec()[0] < 0.001,
            "Perfect prediction should have near-zero loss"
        );
    }

    #[test]
    fn test_cross_entropy_uniform_prediction() {
        let loss_fn = CrossEntropyLoss::new();
        let num_classes = 16;
        let input = Variable::new(
            Tensor::from_vec(vec![0.0; num_classes], &[1, num_classes])
                .expect("tensor creation failed"),
            false,
        );
        let target = Variable::new(
            Tensor::from_vec(vec![0.0], &[1]).expect("tensor creation failed"),
            false,
        );
        let loss = loss_fn.compute(&input, &target);
        // Uniform logits give a softmax of 1/C, so the loss is ln(C).
        let expected = (num_classes as f32).ln();
        let actual = loss.data().to_vec()[0];
        assert!(
            (actual - expected).abs() < 0.01,
            "Uniform logits should give ln(C)={}, got {}",
            expected,
            actual,
        );
    }

    #[test]
    fn test_bce_with_logits_gradient_flow() {
        use axonml_autograd::backward;

        let input = Variable::new(
            Tensor::from_vec(vec![0.5, -0.5, 1.0, -1.0], &[4]).expect("tensor creation failed"),
            true,
        );
        let target = Variable::new(
            Tensor::from_vec(vec![1.0, 0.0, 1.0, 0.0], &[4]).expect("tensor creation failed"),
            false,
        );

        let loss_fn = BCEWithLogitsLoss::new();
        let loss = loss_fn.compute(&input, &target);
        assert!(loss.data().to_vec()[0] > 0.0);

        let ones = Tensor::from_vec(vec![1.0], &loss.shape()).unwrap();
        backward(&loss, &ones);

        let grad = input.grad().expect("Input should have gradient");
        let grad_vec = grad.to_vec();
        assert_eq!(grad_vec.len(), 4);

        // grad = sigmoid(x) - t: negative where t = 1, positive where t = 0.
        assert!(grad_vec[0] < 0.0);
        assert!(grad_vec[1] > 0.0);
    }

    #[test]
    fn test_smooth_l1_gradient_flow() {
        use axonml_autograd::backward;

        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 5.0], &[3]).expect("tensor creation failed"),
            true,
        );
        let target = Variable::new(
            Tensor::from_vec(vec![1.5, 1.5, 1.5], &[3]).expect("tensor creation failed"),
            false,
        );

        let loss_fn = SmoothL1Loss::new();
        let loss = loss_fn.compute(&input, &target);
        assert!(loss.data().to_vec()[0] > 0.0);

        let ones = Tensor::from_vec(vec![1.0], &loss.shape()).unwrap();
        backward(&loss, &ones);

        let grad = input.grad().expect("Input should have gradient");
        let grad_vec = grad.to_vec();
        assert_eq!(grad_vec.len(), 3);
        let grad_norm: f32 = grad_vec.iter().map(|g| g * g).sum();
        assert!(grad_norm > 1e-10);
    }

    #[test]
    fn test_mse_loss_gradient_correctness() {
        use axonml_autograd::backward;

        let input = Variable::new(Tensor::from_vec(vec![3.0, 1.0], &[2]).unwrap(), true);
        let target = Variable::new(Tensor::from_vec(vec![1.0, 1.0], &[2]).unwrap(), false);

        // mean((3-1)^2, (1-1)^2) = 2.0
        let loss = MSELoss::new().compute(&input, &target);
        assert!(
            (loss.data().to_vec()[0] - 2.0).abs() < 1e-5,
            "MSE should be 2.0"
        );

        let ones = Tensor::from_vec(vec![1.0], &loss.shape()).unwrap();
        backward(&loss, &ones);

        // d/dx mean((x - t)^2) = 2 * (x - t) / n
        let grad = input.grad().expect("Should have gradient");
        let gv = grad.to_vec();
        assert!(
            (gv[0] - 2.0).abs() < 0.1,
            "Grad[0] should be ~2.0, got {}",
            gv[0]
        );
        assert!(gv[1].abs() < 0.1, "Grad[1] should be ~0.0, got {}", gv[1]);
    }

    #[test]
    fn test_mse_loss_reduction_sum() {
        let input = Variable::new(Tensor::from_vec(vec![2.0, 4.0], &[2]).unwrap(), false);
        let target = Variable::new(Tensor::from_vec(vec![1.0, 1.0], &[2]).unwrap(), false);
        let loss = MSELoss::with_reduction(Reduction::Sum).compute(&input, &target);
        // (2-1)^2 + (4-1)^2 = 10
        assert!((loss.data().to_vec()[0] - 10.0).abs() < 1e-5);
    }

    #[test]
    fn test_l1_loss_basic() {
        let input = Variable::new(Tensor::from_vec(vec![1.0, 5.0, 3.0], &[3]).unwrap(), false);
        let target = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 4.0], &[3]).unwrap(), false);
        let loss = L1Loss::new().compute(&input, &target);
        // mean(|0|, |3|, |-1|) = 4/3
        assert!((loss.data().to_vec()[0] - 4.0 / 3.0).abs() < 1e-4);
    }

    #[test]
    fn test_l1_loss_zero() {
        let input = Variable::new(Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap(), false);
        let target = Variable::new(Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap(), false);
        let loss = L1Loss::new().compute(&input, &target);
        assert!(
            loss.data().to_vec()[0].abs() < 1e-6,
            "Identical inputs should give 0 loss"
        );
    }

    #[test]
    fn test_bce_loss_perfect_prediction() {
        let loss_fn = BCELoss::new();
        let input = Variable::new(Tensor::from_vec(vec![0.999, 0.001], &[2]).unwrap(), false);
        let target = Variable::new(Tensor::from_vec(vec![1.0, 0.0], &[2]).unwrap(), false);
        let loss = loss_fn.compute(&input, &target);
        assert!(
            loss.data().to_vec()[0] < 0.01,
            "Perfect prediction should have near-zero loss"
        );
    }

    #[test]
    fn test_bce_loss_worst_prediction() {
        let loss_fn = BCELoss::new();
        let input = Variable::new(Tensor::from_vec(vec![0.001, 0.999], &[2]).unwrap(), false);
        let target = Variable::new(Tensor::from_vec(vec![1.0, 0.0], &[2]).unwrap(), false);
        let loss = loss_fn.compute(&input, &target);
        // -ln(0.001) ≈ 6.9 per element.
        assert!(
            loss.data().to_vec()[0] > 3.0,
            "Worst prediction should have high loss"
        );
    }

    #[test]
    fn test_bce_with_logits_numerical_stability() {
        let loss_fn = BCEWithLogitsLoss::new();
        let input = Variable::new(
            Tensor::from_vec(vec![100.0, -100.0, 50.0, -50.0], &[4]).unwrap(),
            false,
        );
        let target = Variable::new(
            Tensor::from_vec(vec![1.0, 0.0, 1.0, 0.0], &[4]).unwrap(),
            false,
        );
        let loss = loss_fn.compute(&input, &target);
        let val = loss.data().to_vec()[0];
        assert!(
            val.is_finite(),
            "Loss should be finite for large logits, got {}",
            val
        );
        assert!(val >= 0.0, "BCE loss should be non-negative");
    }

    #[test]
    fn test_bce_with_logits_zero_logits() {
        let loss_fn = BCEWithLogitsLoss::new();
        let input = Variable::new(Tensor::from_vec(vec![0.0, 0.0], &[2]).unwrap(), false);
        let target = Variable::new(Tensor::from_vec(vec![1.0, 0.0], &[2]).unwrap(), false);
        let loss = loss_fn.compute(&input, &target);
        // Zero logits give sigmoid(0) = 0.5, so the loss is ln(2) ≈ 0.693.
        assert!((loss.data().to_vec()[0] - 0.693).abs() < 0.01);
    }

    #[test]
    fn test_bce_with_logits_reduction_none() {
        let loss_fn = BCEWithLogitsLoss::with_reduction(Reduction::None);
        let input = Variable::new(Tensor::from_vec(vec![0.0, 0.0, 0.0], &[3]).unwrap(), false);
        let target = Variable::new(Tensor::from_vec(vec![1.0, 0.0, 1.0], &[3]).unwrap(), false);
        let loss = loss_fn.compute(&input, &target);
        assert_eq!(loss.shape().len(), 1);
        assert_eq!(loss.shape()[0], 3);
    }

    #[test]
    fn test_smooth_l1_small_error() {
        let loss_fn = SmoothL1Loss::new();
        let input = Variable::new(Tensor::from_vec(vec![1.0], &[1]).unwrap(), false);
        let target = Variable::new(Tensor::from_vec(vec![1.3], &[1]).unwrap(), false);
        let loss = loss_fn.compute(&input, &target);
        // |diff| = 0.3 < beta = 1.0, so loss = 0.5 * 0.3^2 = 0.045.
        assert!((loss.data().to_vec()[0] - 0.045).abs() < 0.01);
    }

    #[test]
    fn test_smooth_l1_large_error() {
        let loss_fn = SmoothL1Loss::new();
        let input = Variable::new(Tensor::from_vec(vec![0.0], &[1]).unwrap(), false);
        let target = Variable::new(Tensor::from_vec(vec![5.0], &[1]).unwrap(), false);
        let loss = loss_fn.compute(&input, &target);
        // |diff| = 5.0 >= beta, so loss = 5.0 - 0.5 = 4.5.
        assert!((loss.data().to_vec()[0] - 4.5).abs() < 0.1);
    }

    #[test]
    fn test_cross_entropy_batch_independence() {
        let loss_fn = CrossEntropyLoss::new();

        let input1 = Variable::new(
            Tensor::from_vec(vec![2.0, 1.0, 0.1], &[1, 3]).unwrap(),
            false,
        );
        let target1 = Variable::new(Tensor::from_vec(vec![0.0], &[1]).unwrap(), false);
        let loss1 = loss_fn.compute(&input1, &target1).data().to_vec()[0];

        let input2 = Variable::new(
            Tensor::from_vec(vec![2.0, 1.0, 0.1, 2.0, 1.0, 0.1], &[2, 3]).unwrap(),
            false,
        );
        let target2 = Variable::new(Tensor::from_vec(vec![0.0, 0.0], &[2]).unwrap(), false);
        let loss2 = loss_fn.compute(&input2, &target2).data().to_vec()[0];

        assert!(
            (loss1 - loss2).abs() < 1e-5,
            "Duplicated batch should give same loss: {} vs {}",
            loss1,
            loss2
        );
    }

    #[test]
    fn test_cross_entropy_high_class_count() {
        let n_classes = 100;
        let mut logits = vec![0.0f32; n_classes];
        // Make class 42 clearly the correct prediction.
        logits[42] = 5.0;

        let loss_fn = CrossEntropyLoss::new();
        let input = Variable::new(Tensor::from_vec(logits, &[1, n_classes]).unwrap(), false);
        let target = Variable::new(Tensor::from_vec(vec![42.0], &[1]).unwrap(), false);
        let loss = loss_fn.compute(&input, &target);
        let val = loss.data().to_vec()[0];
        assert!(val.is_finite(), "Should handle 100 classes");
        assert!(val < 1.0, "Correct class should have low loss, got {}", val);
    }
}