use std::any::Any;

use axonml_autograd::no_grad::is_grad_enabled;
use axonml_autograd::{GradFn, GradientFunction, Variable};
use axonml_tensor::Tensor;

use crate::module::Module;

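/// How element-wise loss values are reduced into the returned loss.
///
/// `None` returns the per-element loss tensor, `Mean` (the default) averages
/// over all elements, and `Sum` adds them up.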
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub enum Reduction {
    None,
    #[default]
    Mean,
    Sum,
}

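/// Mean squared error: the element-wise squared difference `(input - target)^2`,
/// reduced according to [`Reduction`] (mean by default).
///
/// A minimal sketch of the intended call pattern, mirroring the unit tests at
/// the bottom of this file (imports depend on how the crate is consumed, so
/// the snippet is not compiled as a doctest):
///
/// ```ignore
/// let input = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), false);
/// let target = Variable::new(Tensor::from_vec(vec![2.0, 3.0, 4.0], &[3]).unwrap(), false);
/// let loss = MSELoss::new().compute(&input, &target); // mean of [1, 1, 1] = 1.0
/// ```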
#[derive(Debug, Clone, Copy)]
pub struct MSELoss {
    reduction: Reduction,
}

impl MSELoss {
    pub fn new() -> Self {
        Self {
            reduction: Reduction::Mean,
        }
    }

    pub fn with_reduction(reduction: Reduction) -> Self {
        Self { reduction }
    }

    pub fn compute(&self, input: &Variable, target: &Variable) -> Variable {
        let diff = input.sub_var(target);
        let squared = diff.pow(2.0);

        match self.reduction {
            Reduction::None => squared,
            Reduction::Mean => squared.mean(),
            Reduction::Sum => squared.sum(),
        }
    }
}

impl Default for MSELoss {
    fn default() -> Self {
        Self::new()
    }
}

impl Module for MSELoss {
    fn forward(&self, input: &Variable) -> Variable {
        // A loss needs a target, which `Module::forward` does not provide, so
        // `forward` is a pass-through; use `MSELoss::compute` to evaluate the loss.
        input.clone()
    }

    fn name(&self) -> &'static str {
        "MSELoss"
    }
}

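/// Mean absolute error: the element-wise `|input - target|`, reduced according
/// to [`Reduction`]. The absolute value is computed as `relu(x) + relu(-x)`.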
#[derive(Debug, Clone, Copy)]
pub struct L1Loss {
    reduction: Reduction,
}

impl L1Loss {
    pub fn new() -> Self {
        Self {
            reduction: Reduction::Mean,
        }
    }

    pub fn with_reduction(reduction: Reduction) -> Self {
        Self { reduction }
    }

    pub fn compute(&self, input: &Variable, target: &Variable) -> Variable {
        let input_data = input.data();
        let target_data = target.data();
        let diff_tensor = input_data.sub(&target_data).expect("tensor sub failed");
        // |x| computed as relu(x) + relu(-x).
        let relu_diff = axonml_tensor::ops::clamp_min(&diff_tensor, 0.0);
        let relu_neg_diff = axonml_tensor::ops::clamp_min(&diff_tensor.neg(), 0.0);
        let abs_tensor = relu_diff.add(&relu_neg_diff).expect("tensor add failed");

        let requires_grad = (input.requires_grad() || target.requires_grad()) && is_grad_enabled();
        let loss_var = if requires_grad {
            let grad_fn = GradFn::new(L1LossBackward {
                next_fns: vec![input.grad_fn().cloned(), target.grad_fn().cloned()],
                diff_tensor,
            });
            Variable::from_operation(abs_tensor, grad_fn, true)
        } else {
            Variable::new(abs_tensor, false)
        };

        match self.reduction {
            Reduction::None => loss_var,
            Reduction::Mean => loss_var.mean(),
            Reduction::Sum => loss_var.sum(),
        }
    }
}

impl Default for L1Loss {
    fn default() -> Self {
        Self::new()
    }
}

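/// Backward pass for [`L1Loss`]: the gradient w.r.t. the input is
/// `sign(diff) * grad_output` (and its negation w.r.t. the target), with
/// `sign(diff)` approximated as `diff / sqrt(diff^2 + eps)`.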
#[derive(Debug)]
struct L1LossBackward {
    next_fns: Vec<Option<GradFn>>,
    diff_tensor: Tensor<f32>,
}

impl GradientFunction for L1LossBackward {
    fn apply(&self, grad_output: &Tensor<f32>) -> Vec<Option<Tensor<f32>>> {
        // d|x|/dx = sign(x), computed as x / sqrt(x^2 + eps) to avoid a
        // division by zero at x == 0.
        let eps_tensor = Tensor::full(self.diff_tensor.shape(), 1e-12);
        let eps_on_device = if self.diff_tensor.device().is_gpu() {
            eps_tensor.to_device(self.diff_tensor.device()).unwrap()
        } else {
            eps_tensor
        };
        let diff_sq = self
            .diff_tensor
            .mul(&self.diff_tensor)
            .expect("tensor mul failed");
        let diff_sq_eps = diff_sq.add(&eps_on_device).expect("tensor add failed");
        // sqrt(x) expressed as exp(0.5 * ln(x)).
        let abs_diff = diff_sq_eps.ln().mul_scalar(0.5).exp();
        let sign_diff = self.diff_tensor.div(&abs_diff).unwrap();

        let gi = sign_diff.mul(grad_output).unwrap();
        let gt = gi.neg();
        vec![Some(gi), Some(gt)]
    }

    fn name(&self) -> &'static str {
        "L1LossBackward"
    }

    fn next_functions(&self) -> &[Option<GradFn>] {
        &self.next_fns
    }

    fn as_any(&self) -> &dyn Any {
        self
    }
}

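/// Backward pass for [`CrossEntropyLoss`]: for each row the gradient w.r.t.
/// the logits is `softmax(logits) - one_hot(target)`, scaled by the incoming
/// gradient.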
#[derive(Debug)]
struct CrossEntropyBackward {
    next_fns: Vec<Option<GradFn>>,
    softmax_probs: Tensor<f32>,
    targets: Tensor<f32>,
    batch_size: usize,
    num_classes: usize,
}

impl GradientFunction for CrossEntropyBackward {
    fn apply(&self, grad_output: &Tensor<f32>) -> Vec<Option<Tensor<f32>>> {
        let softmax_vec = self.softmax_probs.to_vec();
        let target_vec = self.targets.to_vec();
        let grad_vec = grad_output.to_vec();
        let mut grad_input = vec![0.0f32; self.batch_size * self.num_classes];

        let is_scalar_grad = grad_vec.len() == 1;

        for b in 0..self.batch_size {
            let grad_scale = if is_scalar_grad {
                grad_vec[0]
            } else if b < grad_vec.len() {
                grad_vec[b]
            } else {
                1.0 / self.batch_size as f32
            };
            let offset = b * self.num_classes;
            let tc = target_vec[b] as usize;
            for c in 0..self.num_classes {
                let mut g = softmax_vec[offset + c];
                if c == tc {
                    g -= 1.0;
                }
                grad_input[offset + c] = g * grad_scale;
            }
        }

        let mut grad_tensor = Tensor::from_vec(grad_input, &[self.batch_size, self.num_classes])
            .expect("tensor creation failed");
        if self.softmax_probs.device().is_gpu() {
            grad_tensor = grad_tensor.to_device(self.softmax_probs.device()).unwrap();
        }
        vec![Some(grad_tensor)]
    }

    fn name(&self) -> &'static str {
        "CrossEntropyBackward"
    }

    fn next_functions(&self) -> &[Option<GradFn>] {
        &self.next_fns
    }

    fn as_any(&self) -> &dyn Any {
        self
    }
}

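/// Cross-entropy loss over raw logits of shape `[batch, num_classes]` with
/// class-index targets of shape `[batch]`. Softmax and log-sum-exp are fused
/// into a single, numerically stable forward pass.
///
/// A minimal sketch, mirroring the unit tests below (not compiled as a doctest):
///
/// ```ignore
/// let logits = Variable::new(
///     Tensor::from_vec(vec![2.0, 1.0, 0.1], &[1, 3]).unwrap(),
///     true,
/// );
/// let target = Variable::new(Tensor::from_vec(vec![0.0], &[1]).unwrap(), false);
/// let loss = CrossEntropyLoss::new().compute(&logits, &target);
/// ```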
#[derive(Debug, Clone, Copy)]
pub struct CrossEntropyLoss {
    reduction: Reduction,
}

impl CrossEntropyLoss {
    pub fn new() -> Self {
        Self {
            reduction: Reduction::Mean,
        }
    }

    pub fn with_reduction(reduction: Reduction) -> Self {
        Self { reduction }
    }

    pub fn compute(&self, input: &Variable, target: &Variable) -> Variable {
        let input_data = input.data();
        let target_data = target.data();
        let shape = input_data.shape().to_vec();
        let batch_size = shape[0];
        let num_classes = shape[1];

        // Fused CUDA forward path: compute loss and softmax on the device.
        #[cfg(feature = "cuda")]
        if input_data.device().is_gpu() {
            let targets_gpu = if target_data.device().is_gpu() {
                target_data.clone()
            } else {
                target_data.to_device(input_data.device()).unwrap()
            };

            let (loss_tensor, softmax_tensor) = input_data.cross_entropy_fwd_cuda(&targets_gpu);

            let loss_var = if input.requires_grad() {
                let grad_fn = GradFn::new(CrossEntropyBackward {
                    next_fns: vec![input.grad_fn().cloned()],
                    softmax_probs: softmax_tensor,
                    targets: targets_gpu,
                    batch_size,
                    num_classes,
                });
                Variable::from_operation(loss_tensor, grad_fn, true)
            } else {
                Variable::new(loss_tensor, false)
            };

            return match self.reduction {
                Reduction::None => loss_var,
                Reduction::Mean => loss_var.mean(),
                Reduction::Sum => loss_var.sum(),
            };
        }

        // CPU path: numerically stable softmax / log-sum-exp per row.
        let input_vec = input_data.to_vec();
        let target_vec = target_data.to_vec();

        let mut losses = vec![0.0f32; batch_size];
        let mut softmax_probs_vec = vec![0.0f32; batch_size * num_classes];
        let mut target_classes = vec![0usize; batch_size];

        for b in 0..batch_size {
            let offset = b * num_classes;

            // Subtract the row maximum before exponentiating for stability.
            let max_val = (0..num_classes)
                .map(|c| input_vec[offset + c])
                .fold(f32::NEG_INFINITY, f32::max);

            let mut sum_exp = 0.0f32;
            for c in 0..num_classes {
                let exp_val = (input_vec[offset + c] - max_val).exp();
                softmax_probs_vec[offset + c] = exp_val;
                sum_exp += exp_val;
            }

            for c in 0..num_classes {
                softmax_probs_vec[offset + c] /= sum_exp;
            }

            let log_sum_exp = max_val + sum_exp.ln();

            // loss_b = log_sum_exp(logits_b) - logit_b[target_b]
            let tc = target_vec[b] as usize;
            target_classes[b] = tc;
            losses[b] = log_sum_exp - input_vec[offset + tc];
        }

        let loss_tensor = Tensor::from_vec(losses, &[batch_size]).expect("tensor creation failed");
        let softmax_tensor = Tensor::from_vec(softmax_probs_vec, &[batch_size, num_classes])
            .expect("tensor creation failed");
        let targets_f32: Vec<f32> = target_classes.iter().map(|&tc| tc as f32).collect();
        let targets_tensor =
            Tensor::from_vec(targets_f32, &[batch_size]).expect("tensor creation failed");

        let loss_var = if input.requires_grad() {
            let grad_fn = GradFn::new(CrossEntropyBackward {
                next_fns: vec![input.grad_fn().cloned()],
                softmax_probs: softmax_tensor,
                targets: targets_tensor,
                batch_size,
                num_classes,
            });
            Variable::from_operation(loss_tensor, grad_fn, true)
        } else {
            Variable::new(loss_tensor, false)
        };

        match self.reduction {
            Reduction::None => loss_var,
            Reduction::Mean => loss_var.mean(),
            Reduction::Sum => loss_var.sum(),
        }
    }
}

impl Default for CrossEntropyLoss {
    fn default() -> Self {
        Self::new()
    }
}

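/// Negative log-likelihood loss: takes per-class scores of shape
/// `[batch, num_classes]` (typically the output of a log-softmax) and
/// class-index targets of shape `[batch]`, returning `-input[b, target[b]]`
/// per sample before reduction.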
#[derive(Debug, Clone, Copy)]
pub struct NLLLoss {
    reduction: Reduction,
}

impl NLLLoss {
    pub fn new() -> Self {
        Self {
            reduction: Reduction::Mean,
        }
    }

    pub fn with_reduction(reduction: Reduction) -> Self {
        Self { reduction }
    }

    pub fn compute(&self, input: &Variable, target: &Variable) -> Variable {
        let input_data = input.data();
        let target_data = target.data();
        let shape = input_data.shape().to_vec();
        let batch_size = shape[0];
        let num_classes = shape[1];

        let target_vec = target_data.to_vec();
        let input_vec = input_data.to_vec();

        let mut losses = vec![0.0f32; batch_size];
        for b in 0..batch_size {
            let tc = target_vec[b] as usize;
            losses[b] = -input_vec[b * num_classes + tc];
        }

        let mut loss_tensor =
            Tensor::from_vec(losses, &[batch_size]).expect("tensor creation failed");
        if input_data.device().is_gpu() {
            loss_tensor = loss_tensor.to_device(input_data.device()).unwrap();
        }

        let requires_grad = input.requires_grad() && is_grad_enabled();
        let loss_var = if requires_grad {
            let grad_fn = GradFn::new(NLLLossBackward {
                next_fns: vec![input.grad_fn().cloned()],
                target_tensor: target_data.clone(),
                batch_size,
                num_classes,
            });
            Variable::from_operation(loss_tensor, grad_fn, true)
        } else {
            Variable::new(loss_tensor, false)
        };

        match self.reduction {
            Reduction::None => loss_var,
            Reduction::Mean => loss_var.mean(),
            Reduction::Sum => loss_var.sum(),
        }
    }
}

impl Default for NLLLoss {
    fn default() -> Self {
        Self::new()
    }
}

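/// Backward pass for [`NLLLoss`]: the gradient w.r.t. the input is
/// `-grad_output` at each sample's target class and zero elsewhere.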
#[derive(Debug)]
struct NLLLossBackward {
    next_fns: Vec<Option<GradFn>>,
    target_tensor: Tensor<f32>,
    batch_size: usize,
    num_classes: usize,
}

impl GradientFunction for NLLLossBackward {
    fn apply(&self, grad_output: &Tensor<f32>) -> Vec<Option<Tensor<f32>>> {
        let grad_out_vec = grad_output.to_vec();
        let target_vec = self.target_tensor.to_vec();
        let mut grad_input = vec![0.0f32; self.batch_size * self.num_classes];

        for b in 0..self.batch_size {
            let g = if grad_out_vec.len() == 1 {
                grad_out_vec[0]
            } else {
                grad_out_vec[b]
            };
            let tc = target_vec[b] as usize;
            grad_input[b * self.num_classes + tc] = -g;
        }

        let mut gi = Tensor::from_vec(grad_input, &[self.batch_size, self.num_classes])
            .expect("tensor creation failed");
        if grad_output.device().is_gpu() {
            gi = gi.to_device(grad_output.device()).unwrap();
        }
        vec![Some(gi)]
    }

    fn name(&self) -> &'static str {
        "NLLLossBackward"
    }

    fn next_functions(&self) -> &[Option<GradFn>] {
        &self.next_fns
    }

    fn as_any(&self) -> &dyn Any {
        self
    }
}

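/// Binary cross-entropy over probabilities in `[0, 1]`:
/// `-(y * ln(p) + (1 - y) * ln(1 - p))`, with `p` clamped to `[1e-7, 1 - 1e-7]`
/// before the logarithms. For raw logits, prefer [`BCEWithLogitsLoss`].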
#[derive(Debug, Clone, Copy)]
pub struct BCELoss {
    reduction: Reduction,
}

impl BCELoss {
    pub fn new() -> Self {
        Self {
            reduction: Reduction::Mean,
        }
    }

    pub fn with_reduction(reduction: Reduction) -> Self {
        Self { reduction }
    }

    pub fn compute(&self, input: &Variable, target: &Variable) -> Variable {
        let input_data = input.data();
        let target_data = target.data();

        // Clamp probabilities away from 0 and 1 so the logarithms stay finite.
        let eps = 1e-7f32;
        let p_clamped = axonml_tensor::ops::clamp(&input_data, eps, 1.0 - eps);

        // BCE(p, y) = -(y * ln(p) + (1 - y) * ln(1 - p))
        let ln_p = p_clamped.ln();
        let one_minus_p = p_clamped.neg().add_scalar(1.0);
        let ln_one_minus_p = one_minus_p.ln();
        let one_minus_t = target_data.neg().add_scalar(1.0);

        let term1 = target_data.mul(&ln_p).expect("tensor mul failed");
        let term2 = one_minus_t.mul(&ln_one_minus_p).expect("tensor mul failed");
        let loss_tensor = term1.add(&term2).expect("tensor add failed").neg();

        let requires_grad = input.requires_grad() && is_grad_enabled();
        let loss_var = if requires_grad {
            let grad_fn = GradFn::new(BCELossBackward {
                next_fns: vec![input.grad_fn().cloned()],
                input_tensor: input_data,
                target_tensor: target_data,
            });
            Variable::from_operation(loss_tensor, grad_fn, true)
        } else {
            Variable::new(loss_tensor, false)
        };

        match self.reduction {
            Reduction::None => loss_var,
            Reduction::Mean => loss_var.mean(),
            Reduction::Sum => loss_var.sum(),
        }
    }
}

impl Default for BCELoss {
    fn default() -> Self {
        Self::new()
    }
}

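/// Backward pass for [`BCELoss`]: the gradient w.r.t. the probability `p` is
/// `(p - y) / (p * (1 - p))`, scaled by the incoming gradient.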
#[derive(Debug)]
struct BCELossBackward {
    next_fns: Vec<Option<GradFn>>,
    input_tensor: Tensor<f32>,
    target_tensor: Tensor<f32>,
}

impl GradientFunction for BCELossBackward {
    fn apply(&self, grad_output: &Tensor<f32>) -> Vec<Option<Tensor<f32>>> {
        // dBCE/dp = (p - y) / (p * (1 - p)), with p clamped as in the forward pass.
        let eps = 1e-7f32;
        let p_clamped = axonml_tensor::ops::clamp(&self.input_tensor, eps, 1.0 - eps);
        let p_minus_y = p_clamped
            .sub(&self.target_tensor)
            .expect("tensor sub failed");
        let one_minus_p = p_clamped.neg().add_scalar(1.0);
        let denom = p_clamped.mul(&one_minus_p).expect("tensor mul failed");
        let ratio = p_minus_y.div(&denom).unwrap();
        let grad_tensor = grad_output.mul(&ratio).expect("tensor mul failed");
        vec![Some(grad_tensor)]
    }

    fn name(&self) -> &'static str {
        "BCELossBackward"
    }

    fn next_functions(&self) -> &[Option<GradFn>] {
        &self.next_fns
    }

    fn as_any(&self) -> &dyn Any {
        self
    }
}

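/// Backward pass for [`BCEWithLogitsLoss`]: the gradient w.r.t. the logit `x`
/// is `sigmoid(x) - y`, scaled by the incoming gradient.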
#[derive(Debug)]
struct BCEWithLogitsBackward {
    next_fns: Vec<Option<GradFn>>,
    input_tensor: Tensor<f32>,
    target_tensor: Tensor<f32>,
}

impl GradientFunction for BCEWithLogitsBackward {
    fn apply(&self, grad_output: &Tensor<f32>) -> Vec<Option<Tensor<f32>>> {
        let sig = self.input_tensor.sigmoid();
        let sig_minus_t = sig.sub(&self.target_tensor).expect("tensor sub failed");
        let grad_tensor = grad_output.mul(&sig_minus_t).expect("tensor mul failed");
        vec![Some(grad_tensor)]
    }

    fn name(&self) -> &'static str {
        "BCEWithLogitsBackward"
    }

    fn next_functions(&self) -> &[Option<GradFn>] {
        &self.next_fns
    }

    fn as_any(&self) -> &dyn Any {
        self
    }
}

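/// Binary cross-entropy over raw logits, combining the sigmoid and the BCE in
/// the numerically stable form `max(x, 0) - x * y + ln(1 + exp(-|x|))`, which
/// stays finite even for large-magnitude logits.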
#[derive(Debug, Clone, Copy)]
pub struct BCEWithLogitsLoss {
    reduction: Reduction,
}

impl BCEWithLogitsLoss {
    pub fn new() -> Self {
        Self {
            reduction: Reduction::Mean,
        }
    }

    pub fn with_reduction(reduction: Reduction) -> Self {
        Self { reduction }
    }

    pub fn compute(&self, input: &Variable, target: &Variable) -> Variable {
        let input_data = input.data();
        let target_data = target.data();

        // Numerically stable form: loss = max(x, 0) - x * y + ln(1 + exp(-|x|)).
        let relu_x = axonml_tensor::ops::clamp_min(&input_data, 0.0);
        let x_times_t = input_data.mul(&target_data).expect("tensor mul failed");
        let neg_x = input_data.neg();
        let relu_neg_x = axonml_tensor::ops::clamp_min(&neg_x, 0.0);
        let abs_x = relu_x.add(&relu_neg_x).expect("tensor add failed");
        let exp_neg_abs = abs_x.neg().exp();
        let log_term = exp_neg_abs.add_scalar(1.0).ln();
        let loss_tensor = relu_x
            .sub(&x_times_t)
            .expect("tensor sub failed")
            .add(&log_term)
            .expect("tensor add failed");

        let loss_var = if input.requires_grad() {
            let grad_fn = GradFn::new(BCEWithLogitsBackward {
                next_fns: vec![input.grad_fn().cloned()],
                input_tensor: input_data,
                target_tensor: target_data,
            });
            Variable::from_operation(loss_tensor, grad_fn, true)
        } else {
            Variable::new(loss_tensor, false)
        };

        match self.reduction {
            Reduction::None => loss_var,
            Reduction::Mean => loss_var.mean(),
            Reduction::Sum => loss_var.sum(),
        }
    }
}

impl Default for BCEWithLogitsLoss {
    fn default() -> Self {
        Self::new()
    }
}

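/// Backward pass for [`SmoothL1Loss`]: the gradient w.r.t. the input is
/// `diff / beta` where `|diff| < beta` and `sign(diff)` elsewhere (negated for
/// the target), scaled by the incoming gradient.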
#[derive(Debug)]
struct SmoothL1Backward {
    next_fns: Vec<Option<GradFn>>,
    diff_tensor: Tensor<f32>,
    beta: f32,
    shape: Vec<usize>,
}

impl GradientFunction for SmoothL1Backward {
    fn apply(&self, grad_output: &Tensor<f32>) -> Vec<Option<Tensor<f32>>> {
        // |diff| = sqrt(diff^2 + eps), with sqrt(x) expressed as exp(0.5 * ln(x)).
        let eps = 1e-12f32;
        let diff_sq = self
            .diff_tensor
            .mul(&self.diff_tensor)
            .expect("tensor mul failed");
        let diff_sq_eps = diff_sq.add_scalar(eps);
        let abs_diff = diff_sq_eps.ln().mul_scalar(0.5).exp();

        let sign_diff = self.diff_tensor.div(&abs_diff).unwrap();

        // Gradient is diff / beta in the quadratic region and sign(diff) in the
        // linear region; the mask selects per element based on |diff| < beta.
        let grad_l2 = self.diff_tensor.mul_scalar(1.0 / self.beta);
        let grad_l1 = sign_diff;
        let abs_vec = abs_diff.to_vec();
        let beta = self.beta;
        let mask_vec: Vec<f32> = abs_vec
            .iter()
            .map(|&a| if a < beta { 1.0 } else { 0.0 })
            .collect();
        let mut mask = Tensor::from_vec(mask_vec, &self.shape).expect("tensor creation failed");
        if self.diff_tensor.device().is_gpu() {
            mask = mask.to_device(self.diff_tensor.device()).unwrap();
        }
        let inv_mask = mask.neg().add_scalar(1.0);

        let blended = mask
            .mul(&grad_l2)
            .unwrap()
            .add(&inv_mask.mul(&grad_l1).expect("tensor mul failed"))
            .unwrap();

        let gi = blended.mul(grad_output).unwrap();
        let gt = gi.neg();
        vec![Some(gi), Some(gt)]
    }

    fn name(&self) -> &'static str {
        "SmoothL1Backward"
    }

    fn next_functions(&self) -> &[Option<GradFn>] {
        &self.next_fns
    }

    fn as_any(&self) -> &dyn Any {
        self
    }
}

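/// Smooth L1 (Huber-style) loss: `0.5 * diff^2 / beta` where `|diff| < beta`
/// and `|diff| - 0.5 * beta` otherwise, reduced according to [`Reduction`].
/// `beta` defaults to `1.0`.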
#[derive(Debug, Clone, Copy)]
pub struct SmoothL1Loss {
    reduction: Reduction,
    beta: f32,
}

impl SmoothL1Loss {
    pub fn new() -> Self {
        Self {
            reduction: Reduction::Mean,
            beta: 1.0,
        }
    }

    pub fn with_beta(beta: f32) -> Self {
        Self {
            reduction: Reduction::Mean,
            beta,
        }
    }

    pub fn compute(&self, input: &Variable, target: &Variable) -> Variable {
        let input_data = input.data();
        let target_data = target.data();
        let diff_tensor = input_data.sub(&target_data).expect("tensor sub failed");
        let shape = diff_tensor.shape().to_vec();

        // |diff| computed as relu(diff) + relu(-diff).
        let relu_diff = axonml_tensor::ops::clamp_min(&diff_tensor, 0.0);
        let relu_neg_diff = axonml_tensor::ops::clamp_min(&diff_tensor.neg(), 0.0);
        let abs_diff = relu_diff.add(&relu_neg_diff).expect("tensor add failed");

        // Quadratic branch: 0.5 * diff^2 / beta (used where |diff| < beta).
        let diff_sq = diff_tensor.mul(&diff_tensor).expect("tensor mul failed");
        let l2_loss = diff_sq.mul_scalar(0.5 / self.beta);

        // Linear branch: |diff| - 0.5 * beta (used where |diff| >= beta).
        let l1_loss = abs_diff.add_scalar(-0.5 * self.beta);

        // Per-element mask selecting the quadratic branch.
        let abs_vec = abs_diff.to_vec();
        let beta = self.beta;
        let mask_vec: Vec<f32> = abs_vec
            .iter()
            .map(|&a| if a < beta { 1.0 } else { 0.0 })
            .collect();
        let mut mask = Tensor::from_vec(mask_vec, &shape).expect("tensor creation failed");
        if diff_tensor.device().is_gpu() {
            mask = mask.to_device(diff_tensor.device()).unwrap();
        }
        let inv_mask = mask.neg().add_scalar(1.0);

        let loss_tensor = mask
            .mul(&l2_loss)
            .unwrap()
            .add(&inv_mask.mul(&l1_loss).expect("tensor mul failed"))
            .unwrap();

        let loss_var = if input.requires_grad() || target.requires_grad() {
            let grad_fn = GradFn::new(SmoothL1Backward {
                next_fns: vec![input.grad_fn().cloned(), target.grad_fn().cloned()],
                diff_tensor,
                beta: self.beta,
                shape,
            });
            Variable::from_operation(loss_tensor, grad_fn, true)
        } else {
            Variable::new(loss_tensor, false)
        };

        match self.reduction {
            Reduction::None => loss_var,
            Reduction::Mean => loss_var.mean(),
            Reduction::Sum => loss_var.sum(),
        }
    }
}

impl Default for SmoothL1Loss {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_mse_loss() {
        let loss_fn = MSELoss::new();
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("tensor creation failed"),
            false,
        );
        let target = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("tensor creation failed"),
            false,
        );
        let loss = loss_fn.compute(&input, &target);
        assert!((loss.data().to_vec()[0] - 0.0).abs() < 1e-6);
    }

    #[test]
    fn test_mse_loss_nonzero() {
        let loss_fn = MSELoss::new();
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("tensor creation failed"),
            false,
        );
        let target = Variable::new(
            Tensor::from_vec(vec![2.0, 3.0, 4.0], &[3]).expect("tensor creation failed"),
            false,
        );
        let loss = loss_fn.compute(&input, &target);
        // Each element differs by 1, so the mean squared error is 1.0.
        assert!((loss.data().to_vec()[0] - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_cross_entropy_loss() {
        let loss_fn = CrossEntropyLoss::new();
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 1.0, 2.0, 3.0], &[2, 3])
                .expect("tensor creation failed"),
            false,
        );
        let target = Variable::new(
            Tensor::from_vec(vec![2.0, 0.0], &[2]).expect("tensor creation failed"),
            false,
        );
        let loss = loss_fn.compute(&input, &target);
        assert!(loss.data().to_vec()[0] > 0.0);
    }

    #[test]
    fn test_bce_loss() {
        let loss_fn = BCELoss::new();
        let input = Variable::new(
            Tensor::from_vec(vec![0.5, 0.5], &[2]).expect("tensor creation failed"),
            false,
        );
        let target = Variable::new(
            Tensor::from_vec(vec![1.0, 0.0], &[2]).expect("tensor creation failed"),
            false,
        );
        let loss = loss_fn.compute(&input, &target);
        // BCE at p = 0.5 is -ln(0.5) = ln(2) ≈ 0.693 for either target.
        assert!((loss.data().to_vec()[0] - 0.693).abs() < 0.01);
    }

    #[test]
    fn test_cross_entropy_gradient_flow() {
        use axonml_autograd::backward;

        let input = Variable::new(
            Tensor::from_vec(vec![2.0, 1.0, 0.1, 0.5, 2.5, 0.3], &[2, 3])
                .expect("tensor creation failed"),
            true,
        );
        let target = Variable::new(
            Tensor::from_vec(vec![0.0, 1.0], &[2]).expect("tensor creation failed"),
            false,
        );

        let loss_fn = CrossEntropyLoss::new();
        let loss = loss_fn.compute(&input, &target);

        let loss_val = loss.data().to_vec()[0];
        assert!(loss_val > 0.0, "Loss should be positive, got {}", loss_val);

        let ones = Tensor::from_vec(vec![1.0], &loss.shape()).unwrap();
        backward(&loss, &ones);

        let grad = input
            .grad()
            .expect("Input should have gradient after backward");
        let grad_vec = grad.to_vec();

        let grad_norm: f32 = grad_vec.iter().map(|g| g * g).sum();
        assert!(
            grad_norm > 1e-10,
            "Gradient should be non-zero, got norm {}",
            grad_norm
        );

        assert_eq!(grad.shape(), &[2, 3]);

        // Target classes are 0 (first sample) and 1 (second sample).
        assert!(
            grad_vec[0] < 0.0,
            "Gradient for correct class should be negative"
        );
        assert!(
            grad_vec[4] < 0.0,
            "Gradient for correct class should be negative"
        );

        assert!(
            grad_vec[1] > 0.0,
            "Gradient for wrong class should be positive"
        );
        assert!(
            grad_vec[2] > 0.0,
            "Gradient for wrong class should be positive"
        );
    }

    #[test]
    fn test_cross_entropy_perfect_prediction() {
        let loss_fn = CrossEntropyLoss::new();
        let input = Variable::new(
            Tensor::from_vec(vec![10.0, -10.0, -10.0], &[1, 3]).expect("tensor creation failed"),
            false,
        );
        let target = Variable::new(
            Tensor::from_vec(vec![0.0], &[1]).expect("tensor creation failed"),
            false,
        );
        let loss = loss_fn.compute(&input, &target);
        assert!(
            loss.data().to_vec()[0] < 0.001,
            "Perfect prediction should have near-zero loss"
        );
    }

    #[test]
    fn test_cross_entropy_uniform_prediction() {
        let loss_fn = CrossEntropyLoss::new();
        let num_classes = 16;
        let input = Variable::new(
            Tensor::from_vec(vec![0.0; num_classes], &[1, num_classes])
                .expect("tensor creation failed"),
            false,
        );
        let target = Variable::new(
            Tensor::from_vec(vec![0.0], &[1]).expect("tensor creation failed"),
            false,
        );
        let loss = loss_fn.compute(&input, &target);
        let expected = (num_classes as f32).ln();
        let actual = loss.data().to_vec()[0];
        assert!(
            (actual - expected).abs() < 0.01,
            "Uniform logits should give ln(C)={}, got {}",
            expected,
            actual,
        );
    }

    #[test]
    fn test_bce_with_logits_gradient_flow() {
        use axonml_autograd::backward;

        let input = Variable::new(
            Tensor::from_vec(vec![0.5, -0.5, 1.0, -1.0], &[4]).expect("tensor creation failed"),
            true,
        );
        let target = Variable::new(
            Tensor::from_vec(vec![1.0, 0.0, 1.0, 0.0], &[4]).expect("tensor creation failed"),
            false,
        );

        let loss_fn = BCEWithLogitsLoss::new();
        let loss = loss_fn.compute(&input, &target);
        assert!(loss.data().to_vec()[0] > 0.0);

        let ones = Tensor::from_vec(vec![1.0], &loss.shape()).unwrap();
        backward(&loss, &ones);

        let grad = input.grad().expect("Input should have gradient");
        let grad_vec = grad.to_vec();
        assert_eq!(grad_vec.len(), 4);

        // Gradient is sigmoid(x) - y: negative where y = 1, positive where y = 0.
        assert!(grad_vec[0] < 0.0);
        assert!(grad_vec[1] > 0.0);
    }

    #[test]
    fn test_smooth_l1_gradient_flow() {
        use axonml_autograd::backward;

        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 5.0], &[3]).expect("tensor creation failed"),
            true,
        );
        let target = Variable::new(
            Tensor::from_vec(vec![1.5, 1.5, 1.5], &[3]).expect("tensor creation failed"),
            false,
        );

        let loss_fn = SmoothL1Loss::new();
        let loss = loss_fn.compute(&input, &target);
        assert!(loss.data().to_vec()[0] > 0.0);

        let ones = Tensor::from_vec(vec![1.0], &loss.shape()).unwrap();
        backward(&loss, &ones);

        let grad = input.grad().expect("Input should have gradient");
        let grad_vec = grad.to_vec();
        assert_eq!(grad_vec.len(), 3);
        let grad_norm: f32 = grad_vec.iter().map(|g| g * g).sum();
        assert!(grad_norm > 1e-10);
    }

    #[test]
    fn test_mse_loss_gradient_correctness() {
        use axonml_autograd::backward;

        let input = Variable::new(Tensor::from_vec(vec![3.0, 1.0], &[2]).unwrap(), true);
        let target = Variable::new(Tensor::from_vec(vec![1.0, 1.0], &[2]).unwrap(), false);

        let loss = MSELoss::new().compute(&input, &target);
        // mean((3 - 1)^2, (1 - 1)^2) = mean(4, 0) = 2.
        assert!(
            (loss.data().to_vec()[0] - 2.0).abs() < 1e-5,
            "MSE should be 2.0"
        );

        let ones = Tensor::from_vec(vec![1.0], &loss.shape()).unwrap();
        backward(&loss, &ones);

        let grad = input.grad().expect("Should have gradient");
        let gv = grad.to_vec();
        // d/dx mean((x - y)^2) = 2 * (x - y) / n = 2 * 2 / 2 = 2 for the first element.
        assert!(
            (gv[0] - 2.0).abs() < 0.1,
            "Grad[0] should be ~2.0, got {}",
            gv[0]
        );
        assert!(gv[1].abs() < 0.1, "Grad[1] should be ~0.0, got {}", gv[1]);
    }

    #[test]
    fn test_mse_loss_reduction_sum() {
        let input = Variable::new(Tensor::from_vec(vec![2.0, 4.0], &[2]).unwrap(), false);
        let target = Variable::new(Tensor::from_vec(vec![1.0, 1.0], &[2]).unwrap(), false);
        let loss = MSELoss::with_reduction(Reduction::Sum).compute(&input, &target);
        // 1^2 + 3^2 = 10 with sum reduction.
        assert!((loss.data().to_vec()[0] - 10.0).abs() < 1e-5);
    }

    #[test]
    fn test_l1_loss_basic() {
        let input = Variable::new(Tensor::from_vec(vec![1.0, 5.0, 3.0], &[3]).unwrap(), false);
        let target = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 4.0], &[3]).unwrap(), false);
        let loss = L1Loss::new().compute(&input, &target);
        // (|0| + |3| + |-1|) / 3 = 4 / 3 with mean reduction.
        assert!((loss.data().to_vec()[0] - 4.0 / 3.0).abs() < 1e-4);
    }

    #[test]
    fn test_l1_loss_zero() {
        let input = Variable::new(Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap(), false);
        let target = Variable::new(Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap(), false);
        let loss = L1Loss::new().compute(&input, &target);
        assert!(
            loss.data().to_vec()[0].abs() < 1e-6,
            "Identical inputs should give 0 loss"
        );
    }

    #[test]
    fn test_bce_loss_perfect_prediction() {
        let loss_fn = BCELoss::new();
        let input = Variable::new(Tensor::from_vec(vec![0.999, 0.001], &[2]).unwrap(), false);
        let target = Variable::new(Tensor::from_vec(vec![1.0, 0.0], &[2]).unwrap(), false);
        let loss = loss_fn.compute(&input, &target);
        assert!(
            loss.data().to_vec()[0] < 0.01,
            "Perfect prediction should have near-zero loss"
        );
    }

    #[test]
    fn test_bce_loss_worst_prediction() {
        let loss_fn = BCELoss::new();
        let input = Variable::new(Tensor::from_vec(vec![0.001, 0.999], &[2]).unwrap(), false);
        let target = Variable::new(Tensor::from_vec(vec![1.0, 0.0], &[2]).unwrap(), false);
        let loss = loss_fn.compute(&input, &target);
        assert!(
            loss.data().to_vec()[0] > 3.0,
            "Worst prediction should have high loss"
        );
    }

    #[test]
    fn test_bce_with_logits_numerical_stability() {
        let loss_fn = BCEWithLogitsLoss::new();
        let input = Variable::new(
            Tensor::from_vec(vec![100.0, -100.0, 50.0, -50.0], &[4]).unwrap(),
            false,
        );
        let target = Variable::new(
            Tensor::from_vec(vec![1.0, 0.0, 1.0, 0.0], &[4]).unwrap(),
            false,
        );
        let loss = loss_fn.compute(&input, &target);
        let val = loss.data().to_vec()[0];
        assert!(
            val.is_finite(),
            "Loss should be finite for large logits, got {}",
            val
        );
        assert!(val >= 0.0, "BCE loss should be non-negative");
    }

    #[test]
    fn test_bce_with_logits_zero_logits() {
        let loss_fn = BCEWithLogitsLoss::new();
        let input = Variable::new(Tensor::from_vec(vec![0.0, 0.0], &[2]).unwrap(), false);
        let target = Variable::new(Tensor::from_vec(vec![1.0, 0.0], &[2]).unwrap(), false);
        let loss = loss_fn.compute(&input, &target);
        // sigmoid(0) = 0.5, so the loss per element is ln(2) ≈ 0.693.
        assert!((loss.data().to_vec()[0] - 0.693).abs() < 0.01);
    }

    #[test]
    fn test_bce_with_logits_reduction_none() {
        let loss_fn = BCEWithLogitsLoss::with_reduction(Reduction::None);
        let input = Variable::new(Tensor::from_vec(vec![0.0, 0.0, 0.0], &[3]).unwrap(), false);
        let target = Variable::new(Tensor::from_vec(vec![1.0, 0.0, 1.0], &[3]).unwrap(), false);
        let loss = loss_fn.compute(&input, &target);
        assert_eq!(loss.shape().len(), 1);
        assert_eq!(loss.shape()[0], 3);
    }

    #[test]
    fn test_smooth_l1_small_error() {
        let loss_fn = SmoothL1Loss::new();
        let input = Variable::new(Tensor::from_vec(vec![1.0], &[1]).unwrap(), false);
        let target = Variable::new(Tensor::from_vec(vec![1.3], &[1]).unwrap(), false);
        let loss = loss_fn.compute(&input, &target);
        // |diff| = 0.3 < beta, so the loss is 0.5 * 0.3^2 = 0.045.
        assert!((loss.data().to_vec()[0] - 0.045).abs() < 0.01);
    }

    #[test]
    fn test_smooth_l1_large_error() {
        let loss_fn = SmoothL1Loss::new();
        let input = Variable::new(Tensor::from_vec(vec![0.0], &[1]).unwrap(), false);
        let target = Variable::new(Tensor::from_vec(vec![5.0], &[1]).unwrap(), false);
        let loss = loss_fn.compute(&input, &target);
        // |diff| = 5 >= beta, so the loss is 5 - 0.5 = 4.5.
        assert!((loss.data().to_vec()[0] - 4.5).abs() < 0.1);
    }

    #[test]
    fn test_cross_entropy_batch_independence() {
        let loss_fn = CrossEntropyLoss::new();

        let input1 = Variable::new(
            Tensor::from_vec(vec![2.0, 1.0, 0.1], &[1, 3]).unwrap(),
            false,
        );
        let target1 = Variable::new(Tensor::from_vec(vec![0.0], &[1]).unwrap(), false);
        let loss1 = loss_fn.compute(&input1, &target1).data().to_vec()[0];

        let input2 = Variable::new(
            Tensor::from_vec(vec![2.0, 1.0, 0.1, 2.0, 1.0, 0.1], &[2, 3]).unwrap(),
            false,
        );
        let target2 = Variable::new(Tensor::from_vec(vec![0.0, 0.0], &[2]).unwrap(), false);
        let loss2 = loss_fn.compute(&input2, &target2).data().to_vec()[0];

        assert!(
            (loss1 - loss2).abs() < 1e-5,
            "Duplicated batch should give same loss: {} vs {}",
            loss1,
            loss2
        );
    }

    #[test]
    fn test_cross_entropy_high_class_count() {
        let n_classes = 100;
        let mut logits = vec![0.0f32; n_classes];
        logits[42] = 5.0;
        let loss_fn = CrossEntropyLoss::new();
        let input = Variable::new(Tensor::from_vec(logits, &[1, n_classes]).unwrap(), false);
        let target = Variable::new(Tensor::from_vec(vec![42.0], &[1]).unwrap(), false);
        let loss = loss_fn.compute(&input, &target);
        let val = loss.data().to_vec()[0];
        assert!(val.is_finite(), "Should handle 100 classes");
        assert!(val < 1.0, "Correct class should have low loss, got {}", val);
    }
}