1use std::any::Any;
9
10use axonml_autograd::{GradFn, GradientFunction, Variable};
11use axonml_tensor::Tensor;
12
13use crate::module::Module;
14
/// How an element-wise loss is collapsed into a single value.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub enum Reduction {
    /// Keep the per-element losses unreduced.
    None,
    /// Average over all elements (the default).
    #[default]
    Mean,
    /// Sum over all elements.
    Sum,
}
30
/// Mean squared error loss: `(input - target)^2`, collapsed per [`Reduction`].
#[derive(Debug, Clone, Copy)]
pub struct MSELoss {
    // How the per-element squared errors are collapsed.
    reduction: Reduction,
}
42
43impl MSELoss {
44 pub fn new() -> Self {
46 Self {
47 reduction: Reduction::Mean,
48 }
49 }
50
51 pub fn with_reduction(reduction: Reduction) -> Self {
53 Self { reduction }
54 }
55
56 pub fn compute(&self, input: &Variable, target: &Variable) -> Variable {
58 let diff = input.sub_var(target);
59 let squared = diff.pow(2.0);
60
61 match self.reduction {
62 Reduction::None => squared,
63 Reduction::Mean => squared.mean(),
64 Reduction::Sum => squared.sum(),
65 }
66 }
67}
68
69impl Default for MSELoss {
70 fn default() -> Self {
71 Self::new()
72 }
73}
74
impl Module for MSELoss {
    /// Identity pass-through. `Module::forward` only receives a single input,
    /// while a loss needs both a prediction and a target, so the actual loss
    /// math lives in [`MSELoss::compute`]; `forward` just clones its input.
    fn forward(&self, input: &Variable) -> Variable {
        input.clone()
    }

    /// Name identifying this module.
    fn name(&self) -> &'static str {
        "MSELoss"
    }
}
86
/// L1 (mean absolute error) loss: `|input - target|`, collapsed per [`Reduction`].
#[derive(Debug, Clone, Copy)]
pub struct L1Loss {
    // How the per-element absolute errors are collapsed.
    reduction: Reduction,
}
98
99impl L1Loss {
100 pub fn new() -> Self {
102 Self {
103 reduction: Reduction::Mean,
104 }
105 }
106
107 pub fn with_reduction(reduction: Reduction) -> Self {
109 Self { reduction }
110 }
111
112 pub fn compute(&self, input: &Variable, target: &Variable) -> Variable {
114 let diff = input.sub_var(target);
115 let diff_data = diff.data();
116 let abs_data: Vec<f32> = diff_data.to_vec().iter().map(|x| x.abs()).collect();
117 let abs_tensor = Tensor::from_vec(abs_data, diff_data.shape()).unwrap();
118 let abs_var = Variable::new(abs_tensor, diff.requires_grad());
119
120 match self.reduction {
121 Reduction::None => abs_var,
122 Reduction::Mean => abs_var.mean(),
123 Reduction::Sum => abs_var.sum(),
124 }
125 }
126}
127
128impl Default for L1Loss {
129 fn default() -> Self {
130 Self::new()
131 }
132}
133
/// Backward node for [`CrossEntropyLoss`]: stores the softmax probabilities
/// and target indices saved during the forward pass so the gradient
/// `softmax - one_hot(target)` can be formed without recomputation.
#[derive(Debug)]
struct CrossEntropyBackward {
    // Gradient functions of the inputs (only the logits variable here).
    next_fns: Vec<Option<GradFn>>,
    // Softmax probabilities from the forward pass, row-major [batch, num_classes].
    softmax_probs: Vec<f32>,
    // Target class index for each sample.
    target_classes: Vec<usize>,
    // Number of samples in the batch.
    batch_size: usize,
    // Number of classes per sample.
    num_classes: usize,
}
153
154impl GradientFunction for CrossEntropyBackward {
155 fn apply(&self, grad_output: &Tensor<f32>) -> Vec<Option<Tensor<f32>>> {
156 let grad_vec = grad_output.to_vec();
157 let mut grad_input = vec![0.0f32; self.batch_size * self.num_classes];
158
159 for b in 0..self.batch_size {
160 let grad_scale = grad_vec[b];
161 let offset = b * self.num_classes;
162 for c in 0..self.num_classes {
163 let mut g = self.softmax_probs[offset + c];
165 if c == self.target_classes[b] {
166 g -= 1.0;
167 }
168 grad_input[offset + c] = g * grad_scale;
169 }
170 }
171
172 let grad_tensor =
173 Tensor::from_vec(grad_input, &[self.batch_size, self.num_classes]).unwrap();
174 vec![Some(grad_tensor)]
175 }
176
177 fn name(&self) -> &'static str {
178 "CrossEntropyBackward"
179 }
180
181 fn next_functions(&self) -> &[Option<GradFn>] {
182 &self.next_fns
183 }
184
185 fn as_any(&self) -> &dyn Any {
186 self
187 }
188}
189
/// Softmax cross-entropy loss over raw logits with integer class targets,
/// collapsed per [`Reduction`]. Supports gradient flow via [`CrossEntropyBackward`].
#[derive(Debug, Clone, Copy)]
pub struct CrossEntropyLoss {
    // How the per-sample losses are collapsed.
    reduction: Reduction,
}
205
206impl CrossEntropyLoss {
207 pub fn new() -> Self {
209 Self {
210 reduction: Reduction::Mean,
211 }
212 }
213
214 pub fn with_reduction(reduction: Reduction) -> Self {
216 Self { reduction }
217 }
218
219 pub fn compute(&self, input: &Variable, target: &Variable) -> Variable {
225 let input_data = input.data();
226 let target_data = target.data();
227 let shape = input_data.shape().to_vec();
228 let batch_size = shape[0];
229 let num_classes = shape[1];
230
231 let input_vec = input_data.to_vec();
232 let target_vec = target_data.to_vec();
233
234 let mut losses = vec![0.0f32; batch_size];
235 let mut softmax_probs = vec![0.0f32; batch_size * num_classes];
236 let mut target_classes = vec![0usize; batch_size];
237
238 for b in 0..batch_size {
239 let offset = b * num_classes;
240
241 let max_val = (0..num_classes)
243 .map(|c| input_vec[offset + c])
244 .fold(f32::NEG_INFINITY, f32::max);
245
246 let mut sum_exp = 0.0f32;
247 for c in 0..num_classes {
248 let exp_val = (input_vec[offset + c] - max_val).exp();
249 softmax_probs[offset + c] = exp_val;
250 sum_exp += exp_val;
251 }
252
253 for c in 0..num_classes {
255 softmax_probs[offset + c] /= sum_exp;
256 }
257
258 let log_sum_exp = max_val + sum_exp.ln();
259
260 let tc = target_vec[b] as usize;
262 target_classes[b] = tc;
263 losses[b] = log_sum_exp - input_vec[offset + tc];
264 }
265
266 let loss_tensor = Tensor::from_vec(losses, &[batch_size]).unwrap();
267
268 let loss_var = if input.requires_grad() {
269 let grad_fn = GradFn::new(CrossEntropyBackward {
270 next_fns: vec![input.grad_fn().cloned()],
271 softmax_probs,
272 target_classes,
273 batch_size,
274 num_classes,
275 });
276 Variable::from_operation(loss_tensor, grad_fn, true)
277 } else {
278 Variable::new(loss_tensor, false)
279 };
280
281 match self.reduction {
282 Reduction::None => loss_var,
283 Reduction::Mean => loss_var.mean(),
284 Reduction::Sum => loss_var.sum(),
285 }
286 }
287}
288
289impl Default for CrossEntropyLoss {
290 fn default() -> Self {
291 Self::new()
292 }
293}
294
/// Negative log-likelihood loss: `-input[b, target[b]]` per sample,
/// collapsed per [`Reduction`]. Conventionally fed log-probabilities.
#[derive(Debug, Clone, Copy)]
pub struct NLLLoss {
    // How the per-sample losses are collapsed.
    reduction: Reduction,
}
306
307impl NLLLoss {
308 pub fn new() -> Self {
310 Self {
311 reduction: Reduction::Mean,
312 }
313 }
314
315 pub fn with_reduction(reduction: Reduction) -> Self {
317 Self { reduction }
318 }
319
320 pub fn compute(&self, input: &Variable, target: &Variable) -> Variable {
322 let input_data = input.data();
323 let target_data = target.data();
324 let shape = input_data.shape().to_vec();
325 let batch_size = shape[0];
326 let num_classes = shape[1];
327
328 let input_vec = input_data.to_vec();
329 let target_vec = target_data.to_vec();
330
331 let mut losses = vec![0.0f32; batch_size];
332
333 for b in 0..batch_size {
334 let target_class = target_vec[b] as usize;
335 losses[b] = -input_vec[b * num_classes + target_class];
336 }
337
338 let loss_tensor = Tensor::from_vec(losses, &[batch_size]).unwrap();
339 let loss_var = Variable::new(loss_tensor, input.requires_grad());
340
341 match self.reduction {
342 Reduction::None => loss_var,
343 Reduction::Mean => loss_var.mean(),
344 Reduction::Sum => loss_var.sum(),
345 }
346 }
347}
348
349impl Default for NLLLoss {
350 fn default() -> Self {
351 Self::new()
352 }
353}
354
/// Binary cross-entropy loss over probabilities, collapsed per [`Reduction`].
/// Inputs are clamped away from 0 and 1 before taking logs.
#[derive(Debug, Clone, Copy)]
pub struct BCELoss {
    // How the per-element losses are collapsed.
    reduction: Reduction,
}
366
367impl BCELoss {
368 pub fn new() -> Self {
370 Self {
371 reduction: Reduction::Mean,
372 }
373 }
374
375 pub fn with_reduction(reduction: Reduction) -> Self {
377 Self { reduction }
378 }
379
380 pub fn compute(&self, input: &Variable, target: &Variable) -> Variable {
382 let eps = 1e-7f32;
383 let input_data = input.data();
384 let target_data = target.data();
385
386 let input_vec = input_data.to_vec();
387 let target_vec = target_data.to_vec();
388
389 let losses: Vec<f32> = input_vec
390 .iter()
391 .zip(target_vec.iter())
392 .map(|(&p, &t)| {
393 let p_clamped = p.max(eps).min(1.0 - eps);
394 -(t * p_clamped.ln() + (1.0 - t) * (1.0 - p_clamped).ln())
395 })
396 .collect();
397
398 let loss_tensor = Tensor::from_vec(losses, input_data.shape()).unwrap();
399 let loss_var = Variable::new(loss_tensor, input.requires_grad());
400
401 match self.reduction {
402 Reduction::None => loss_var,
403 Reduction::Mean => loss_var.mean(),
404 Reduction::Sum => loss_var.sum(),
405 }
406 }
407}
408
409impl Default for BCELoss {
410 fn default() -> Self {
411 Self::new()
412 }
413}
414
/// Binary cross-entropy computed directly from raw logits using the
/// numerically stable `max(x,0) - x*t + ln(1 + exp(-|x|))` formulation,
/// collapsed per [`Reduction`].
#[derive(Debug, Clone, Copy)]
pub struct BCEWithLogitsLoss {
    // How the per-element losses are collapsed.
    reduction: Reduction,
}
426
427impl BCEWithLogitsLoss {
428 pub fn new() -> Self {
430 Self {
431 reduction: Reduction::Mean,
432 }
433 }
434
435 pub fn with_reduction(reduction: Reduction) -> Self {
437 Self { reduction }
438 }
439
440 pub fn compute(&self, input: &Variable, target: &Variable) -> Variable {
442 let input_data = input.data();
443 let target_data = target.data();
444
445 let input_vec = input_data.to_vec();
446 let target_vec = target_data.to_vec();
447
448 let losses: Vec<f32> = input_vec
450 .iter()
451 .zip(target_vec.iter())
452 .map(|(&x, &t)| {
453 let max_val = x.max(0.0);
454 max_val - x * t + (1.0 + (-x.abs()).exp()).ln()
455 })
456 .collect();
457
458 let loss_tensor = Tensor::from_vec(losses, input_data.shape()).unwrap();
459 let loss_var = Variable::new(loss_tensor, input.requires_grad());
460
461 match self.reduction {
462 Reduction::None => loss_var,
463 Reduction::Mean => loss_var.mean(),
464 Reduction::Sum => loss_var.sum(),
465 }
466 }
467}
468
469impl Default for BCEWithLogitsLoss {
470 fn default() -> Self {
471 Self::new()
472 }
473}
474
/// Smooth L1 (Huber-style) loss: quadratic for residuals below `beta`,
/// linear above, collapsed per [`Reduction`].
#[derive(Debug, Clone, Copy)]
pub struct SmoothL1Loss {
    // How the per-element losses are collapsed.
    reduction: Reduction,
    // Transition point between the quadratic and linear regimes.
    beta: f32,
}
487
488impl SmoothL1Loss {
489 pub fn new() -> Self {
491 Self {
492 reduction: Reduction::Mean,
493 beta: 1.0,
494 }
495 }
496
497 pub fn with_beta(beta: f32) -> Self {
499 Self {
500 reduction: Reduction::Mean,
501 beta,
502 }
503 }
504
505 pub fn compute(&self, input: &Variable, target: &Variable) -> Variable {
507 let diff = input.sub_var(target);
508 let diff_data = diff.data();
509 let diff_vec = diff_data.to_vec();
510
511 let losses: Vec<f32> = diff_vec
512 .iter()
513 .map(|&d| {
514 let abs_d = d.abs();
515 if abs_d < self.beta {
516 0.5 * d * d / self.beta
517 } else {
518 abs_d - 0.5 * self.beta
519 }
520 })
521 .collect();
522
523 let loss_tensor = Tensor::from_vec(losses, diff_data.shape()).unwrap();
524 let loss_var = Variable::new(loss_tensor, diff.requires_grad());
525
526 match self.reduction {
527 Reduction::None => loss_var,
528 Reduction::Mean => loss_var.mean(),
529 Reduction::Sum => loss_var.sum(),
530 }
531 }
532}
533
534impl Default for SmoothL1Loss {
535 fn default() -> Self {
536 Self::new()
537 }
538}
539
#[cfg(test)]
mod tests {
    use super::*;

    // Identical input and target => MSE of exactly zero.
    #[test]
    fn test_mse_loss() {
        let loss_fn = MSELoss::new();
        let input = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), false);
        let target = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), false);
        let loss = loss_fn.compute(&input, &target);
        assert!((loss.data().to_vec()[0] - 0.0).abs() < 1e-6);
    }

    // Constant offset of 1 on every element => mean squared error of 1.
    #[test]
    fn test_mse_loss_nonzero() {
        let loss_fn = MSELoss::new();
        let input = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), false);
        let target = Variable::new(Tensor::from_vec(vec![2.0, 3.0, 4.0], &[3]).unwrap(), false);
        let loss = loss_fn.compute(&input, &target);
        assert!((loss.data().to_vec()[0] - 1.0).abs() < 1e-6);
    }

    // Sanity check: cross-entropy on non-degenerate logits is positive.
    #[test]
    fn test_cross_entropy_loss() {
        let loss_fn = CrossEntropyLoss::new();
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 1.0, 2.0, 3.0], &[2, 3]).unwrap(),
            false,
        );
        // Targets are class indices encoded as f32.
        let target = Variable::new(Tensor::from_vec(vec![2.0, 0.0], &[2]).unwrap(), false);
        let loss = loss_fn.compute(&input, &target);
        assert!(loss.data().to_vec()[0] > 0.0);
    }

    // BCE at p = 0.5 equals ln(2) ~= 0.693 regardless of the label.
    #[test]
    fn test_bce_loss() {
        let loss_fn = BCELoss::new();
        let input = Variable::new(Tensor::from_vec(vec![0.5, 0.5], &[2]).unwrap(), false);
        let target = Variable::new(Tensor::from_vec(vec![1.0, 0.0], &[2]).unwrap(), false);
        let loss = loss_fn.compute(&input, &target);
        assert!((loss.data().to_vec()[0] - 0.693).abs() < 0.01);
    }

    // End-to-end autograd check: backward() through CrossEntropyBackward must
    // produce a non-zero gradient with the softmax-minus-one-hot sign pattern.
    #[test]
    fn test_cross_entropy_gradient_flow() {
        use axonml_autograd::backward;

        let input = Variable::new(
            Tensor::from_vec(
                vec![2.0, 1.0, 0.1, 0.5, 2.5, 0.3],
                &[2, 3],
            )
            .unwrap(),
            true,
        );
        let target = Variable::new(
            Tensor::from_vec(vec![0.0, 1.0], &[2]).unwrap(),
            false,
        );

        let loss_fn = CrossEntropyLoss::new();
        let loss = loss_fn.compute(&input, &target);

        let loss_val = loss.data().to_vec()[0];
        assert!(loss_val > 0.0, "Loss should be positive, got {}", loss_val);

        // Seed the backward pass with dL/dL = 1.
        let ones = Tensor::from_vec(vec![1.0], &loss.shape()).unwrap();
        backward(&loss, &ones);

        let grad = input.grad().expect("Input should have gradient after backward");
        let grad_vec = grad.to_vec();

        // Gradient must be non-zero overall...
        let grad_norm: f32 = grad_vec.iter().map(|g| g * g).sum();
        assert!(
            grad_norm > 1e-10,
            "Gradient should be non-zero, got norm {}",
            grad_norm
        );

        // ...match the input shape...
        assert_eq!(grad.shape(), &[2, 3]);

        // ...and follow softmax - one_hot: negative at the target class
        // (indices 0 and 4), positive elsewhere.
        assert!(grad_vec[0] < 0.0, "Gradient for correct class should be negative");
        assert!(grad_vec[4] < 0.0, "Gradient for correct class should be negative");

        assert!(grad_vec[1] > 0.0, "Gradient for wrong class should be positive");
        assert!(grad_vec[2] > 0.0, "Gradient for wrong class should be positive");
    }

    // A confidently correct prediction (large logit margin) should drive the
    // loss toward zero.
    #[test]
    fn test_cross_entropy_perfect_prediction() {
        let loss_fn = CrossEntropyLoss::new();
        let input = Variable::new(
            Tensor::from_vec(vec![10.0, -10.0, -10.0], &[1, 3]).unwrap(),
            false,
        );
        let target = Variable::new(
            Tensor::from_vec(vec![0.0], &[1]).unwrap(),
            false,
        );
        let loss = loss_fn.compute(&input, &target);
        assert!(loss.data().to_vec()[0] < 0.001, "Perfect prediction should have near-zero loss");
    }

    // Uniform logits give a uniform softmax, so the loss must be ln(C).
    #[test]
    fn test_cross_entropy_uniform_prediction() {
        let loss_fn = CrossEntropyLoss::new();
        let num_classes = 16;
        let input = Variable::new(
            Tensor::from_vec(vec![0.0; num_classes], &[1, num_classes]).unwrap(),
            false,
        );
        let target = Variable::new(
            Tensor::from_vec(vec![0.0], &[1]).unwrap(),
            false,
        );
        let loss = loss_fn.compute(&input, &target);
        let expected = (num_classes as f32).ln(); let actual = loss.data().to_vec()[0];
        assert!(
            (actual - expected).abs() < 0.01,
            "Uniform logits should give ln(C)={}, got {}",
            expected,
            actual,
        );
    }
}