axonml_nn/loss.rs
1//! Loss Functions - Training Objectives
2//!
3//! Provides loss functions for training neural networks.
4//!
5//! @version 0.1.0
6//! @author AutomataNexus Development Team
7
8use std::any::Any;
9
10use axonml_autograd::{GradFn, GradientFunction, Variable};
11use axonml_tensor::Tensor;
12
13use crate::module::Module;
14
15// =============================================================================
16// Reduction Enum
17// =============================================================================
18
/// Specifies how to reduce the loss over elements.
///
/// Every loss in this module accepts a `Reduction` to control whether the
/// per-element losses are returned as-is, averaged, or summed.
// `Eq` and `Hash` are derivable for free on a fieldless enum and let
// `Reduction` be used as a map key or compared with `==` in const contexts.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum Reduction {
    /// No reduction - return loss per element.
    None,
    /// Mean of all losses.
    #[default]
    Mean,
    /// Sum of all losses.
    Sum,
}
30
31// =============================================================================
32// MSELoss
33// =============================================================================
34
35/// Mean Squared Error loss.
36///
37/// loss = mean((input - target)^2)
38#[derive(Debug, Clone, Copy)]
39pub struct MSELoss {
40    reduction: Reduction,
41}
42
43impl MSELoss {
44    /// Creates a new MSELoss with default reduction (Mean).
45    pub fn new() -> Self {
46        Self {
47            reduction: Reduction::Mean,
48        }
49    }
50
51    /// Creates MSELoss with specified reduction.
52    pub fn with_reduction(reduction: Reduction) -> Self {
53        Self { reduction }
54    }
55
56    /// Computes the loss.
57    pub fn compute(&self, input: &Variable, target: &Variable) -> Variable {
58        let diff = input.sub_var(target);
59        let squared = diff.pow(2.0);
60
61        match self.reduction {
62            Reduction::None => squared,
63            Reduction::Mean => squared.mean(),
64            Reduction::Sum => squared.sum(),
65        }
66    }
67}
68
69impl Default for MSELoss {
70    fn default() -> Self {
71        Self::new()
72    }
73}
74
impl Module for MSELoss {
    /// Identity pass-through: `Module::forward` takes a single input, so the
    /// target cannot be supplied through this interface. Use
    /// [`MSELoss::compute`] to evaluate the actual loss; this impl only lets
    /// the loss type participate in `Module`-based containers.
    fn forward(&self, input: &Variable) -> Variable {
        // For Module interface, we can't easily pass two inputs
        // This is primarily used via compute() method
        input.clone()
    }

    /// Human-readable layer name used for debugging/printing.
    fn name(&self) -> &'static str {
        "MSELoss"
    }
}
86
87// =============================================================================
88// L1Loss
89// =============================================================================
90
91/// Mean Absolute Error loss.
92///
93/// loss = mean(|input - target|)
94#[derive(Debug, Clone, Copy)]
95pub struct L1Loss {
96    reduction: Reduction,
97}
98
99impl L1Loss {
100    /// Creates a new L1Loss with default reduction (Mean).
101    pub fn new() -> Self {
102        Self {
103            reduction: Reduction::Mean,
104        }
105    }
106
107    /// Creates L1Loss with specified reduction.
108    pub fn with_reduction(reduction: Reduction) -> Self {
109        Self { reduction }
110    }
111
112    /// Computes the loss.
113    pub fn compute(&self, input: &Variable, target: &Variable) -> Variable {
114        let diff = input.sub_var(target);
115        let diff_data = diff.data();
116        let abs_data: Vec<f32> = diff_data.to_vec().iter().map(|x| x.abs()).collect();
117        let abs_tensor = Tensor::from_vec(abs_data, diff_data.shape()).unwrap();
118        let abs_var = Variable::new(abs_tensor, diff.requires_grad());
119
120        match self.reduction {
121            Reduction::None => abs_var,
122            Reduction::Mean => abs_var.mean(),
123            Reduction::Sum => abs_var.sum(),
124        }
125    }
126}
127
128impl Default for L1Loss {
129    fn default() -> Self {
130        Self::new()
131    }
132}
133
134// =============================================================================
135// CrossEntropyBackward
136// =============================================================================
137
/// Gradient function for CrossEntropyLoss.
///
/// The gradient of CE w.r.t. logits is: softmax(logits) - one_hot(target).
/// For per-sample losses, each sample's gradient is scaled by the upstream
/// gradient (from reduction).
#[derive(Debug)]
struct CrossEntropyBackward {
    /// Gradient functions of the inputs; only the logits entry is populated
    /// (targets are class indices and receive no gradient).
    next_fns: Vec<Option<GradFn>>,
    /// Softmax probabilities computed during forward pass, shape (N, C)
    softmax_probs: Vec<f32>,
    /// Target class indices, shape (N,)
    target_classes: Vec<usize>,
    /// N: number of samples in the batch.
    batch_size: usize,
    /// C: number of classes per sample.
    num_classes: usize,
}
153
154impl GradientFunction for CrossEntropyBackward {
155    fn apply(&self, grad_output: &Tensor<f32>) -> Vec<Option<Tensor<f32>>> {
156        let grad_vec = grad_output.to_vec();
157        let mut grad_input = vec![0.0f32; self.batch_size * self.num_classes];
158
159        for b in 0..self.batch_size {
160            let grad_scale = grad_vec[b];
161            let offset = b * self.num_classes;
162            for c in 0..self.num_classes {
163                // d(CE)/d(logit[b,c]) = softmax[b,c] - (1 if c == target[b])
164                let mut g = self.softmax_probs[offset + c];
165                if c == self.target_classes[b] {
166                    g -= 1.0;
167                }
168                grad_input[offset + c] = g * grad_scale;
169            }
170        }
171
172        let grad_tensor =
173            Tensor::from_vec(grad_input, &[self.batch_size, self.num_classes]).unwrap();
174        vec![Some(grad_tensor)]
175    }
176
177    fn name(&self) -> &'static str {
178        "CrossEntropyBackward"
179    }
180
181    fn next_functions(&self) -> &[Option<GradFn>] {
182        &self.next_fns
183    }
184
185    fn as_any(&self) -> &dyn Any {
186        self
187    }
188}
189
190// =============================================================================
191// CrossEntropyLoss
192// =============================================================================
193
/// Cross entropy loss with log softmax.
///
/// This combines LogSoftmax and NLLLoss in a single class.
///
/// # Shape
/// - Input: (N, C) where C = number of classes
/// - Target: (N,) with class indices
#[derive(Debug, Clone, Copy)]
pub struct CrossEntropyLoss {
    /// How per-sample losses are reduced (element-wise, mean, or sum).
    reduction: Reduction,
}
205
206impl CrossEntropyLoss {
207    /// Creates a new CrossEntropyLoss with default reduction (Mean).
208    pub fn new() -> Self {
209        Self {
210            reduction: Reduction::Mean,
211        }
212    }
213
214    /// Creates CrossEntropyLoss with specified reduction.
215    pub fn with_reduction(reduction: Reduction) -> Self {
216        Self { reduction }
217    }
218
219    /// Computes the loss.
220    ///
221    /// # Arguments
222    /// * `input` - Logits of shape (N, C)
223    /// * `target` - Class indices of shape (N,) as f32 (will be cast to usize)
224    pub fn compute(&self, input: &Variable, target: &Variable) -> Variable {
225        let input_data = input.data();
226        let target_data = target.data();
227        let shape = input_data.shape().to_vec();
228        let batch_size = shape[0];
229        let num_classes = shape[1];
230
231        let input_vec = input_data.to_vec();
232        let target_vec = target_data.to_vec();
233
234        let mut losses = vec![0.0f32; batch_size];
235        let mut softmax_probs = vec![0.0f32; batch_size * num_classes];
236        let mut target_classes = vec![0usize; batch_size];
237
238        for b in 0..batch_size {
239            let offset = b * num_classes;
240
241            // Numerically stable log-softmax
242            let max_val = (0..num_classes)
243                .map(|c| input_vec[offset + c])
244                .fold(f32::NEG_INFINITY, f32::max);
245
246            let mut sum_exp = 0.0f32;
247            for c in 0..num_classes {
248                let exp_val = (input_vec[offset + c] - max_val).exp();
249                softmax_probs[offset + c] = exp_val;
250                sum_exp += exp_val;
251            }
252
253            // Normalize to get softmax probabilities (saved for backward)
254            for c in 0..num_classes {
255                softmax_probs[offset + c] /= sum_exp;
256            }
257
258            let log_sum_exp = max_val + sum_exp.ln();
259
260            // NLL loss: -log_softmax[target_class]
261            let tc = target_vec[b] as usize;
262            target_classes[b] = tc;
263            losses[b] = log_sum_exp - input_vec[offset + tc];
264        }
265
266        let loss_tensor = Tensor::from_vec(losses, &[batch_size]).unwrap();
267
268        let loss_var = if input.requires_grad() {
269            let grad_fn = GradFn::new(CrossEntropyBackward {
270                next_fns: vec![input.grad_fn().cloned()],
271                softmax_probs,
272                target_classes,
273                batch_size,
274                num_classes,
275            });
276            Variable::from_operation(loss_tensor, grad_fn, true)
277        } else {
278            Variable::new(loss_tensor, false)
279        };
280
281        match self.reduction {
282            Reduction::None => loss_var,
283            Reduction::Mean => loss_var.mean(),
284            Reduction::Sum => loss_var.sum(),
285        }
286    }
287}
288
289impl Default for CrossEntropyLoss {
290    fn default() -> Self {
291        Self::new()
292    }
293}
294
295// =============================================================================
296// NLLLoss
297// =============================================================================
298
299/// Negative Log Likelihood loss.
300///
301/// Expects input to be log-probabilities.
302#[derive(Debug, Clone, Copy)]
303pub struct NLLLoss {
304    reduction: Reduction,
305}
306
307impl NLLLoss {
308    /// Creates a new NLLLoss with default reduction (Mean).
309    pub fn new() -> Self {
310        Self {
311            reduction: Reduction::Mean,
312        }
313    }
314
315    /// Creates NLLLoss with specified reduction.
316    pub fn with_reduction(reduction: Reduction) -> Self {
317        Self { reduction }
318    }
319
320    /// Computes the loss.
321    pub fn compute(&self, input: &Variable, target: &Variable) -> Variable {
322        let input_data = input.data();
323        let target_data = target.data();
324        let shape = input_data.shape().to_vec();
325        let batch_size = shape[0];
326        let num_classes = shape[1];
327
328        let input_vec = input_data.to_vec();
329        let target_vec = target_data.to_vec();
330
331        let mut losses = vec![0.0f32; batch_size];
332
333        for b in 0..batch_size {
334            let target_class = target_vec[b] as usize;
335            losses[b] = -input_vec[b * num_classes + target_class];
336        }
337
338        let loss_tensor = Tensor::from_vec(losses, &[batch_size]).unwrap();
339        let loss_var = Variable::new(loss_tensor, input.requires_grad());
340
341        match self.reduction {
342            Reduction::None => loss_var,
343            Reduction::Mean => loss_var.mean(),
344            Reduction::Sum => loss_var.sum(),
345        }
346    }
347}
348
349impl Default for NLLLoss {
350    fn default() -> Self {
351        Self::new()
352    }
353}
354
355// =============================================================================
356// BCELoss
357// =============================================================================
358
359/// Binary Cross Entropy loss.
360///
361/// Expects input to be probabilities in [0, 1].
362#[derive(Debug, Clone, Copy)]
363pub struct BCELoss {
364    reduction: Reduction,
365}
366
367impl BCELoss {
368    /// Creates a new BCELoss with default reduction (Mean).
369    pub fn new() -> Self {
370        Self {
371            reduction: Reduction::Mean,
372        }
373    }
374
375    /// Creates BCELoss with specified reduction.
376    pub fn with_reduction(reduction: Reduction) -> Self {
377        Self { reduction }
378    }
379
380    /// Computes the loss.
381    pub fn compute(&self, input: &Variable, target: &Variable) -> Variable {
382        let eps = 1e-7f32;
383        let input_data = input.data();
384        let target_data = target.data();
385
386        let input_vec = input_data.to_vec();
387        let target_vec = target_data.to_vec();
388
389        let losses: Vec<f32> = input_vec
390            .iter()
391            .zip(target_vec.iter())
392            .map(|(&p, &t)| {
393                let p_clamped = p.max(eps).min(1.0 - eps);
394                -(t * p_clamped.ln() + (1.0 - t) * (1.0 - p_clamped).ln())
395            })
396            .collect();
397
398        let loss_tensor = Tensor::from_vec(losses, input_data.shape()).unwrap();
399        let loss_var = Variable::new(loss_tensor, input.requires_grad());
400
401        match self.reduction {
402            Reduction::None => loss_var,
403            Reduction::Mean => loss_var.mean(),
404            Reduction::Sum => loss_var.sum(),
405        }
406    }
407}
408
409impl Default for BCELoss {
410    fn default() -> Self {
411        Self::new()
412    }
413}
414
415// =============================================================================
416// BCEWithLogitsLoss
417// =============================================================================
418
419/// Binary Cross Entropy with Logits.
420///
421/// Combines sigmoid and BCE in a numerically stable way.
422#[derive(Debug, Clone, Copy)]
423pub struct BCEWithLogitsLoss {
424    reduction: Reduction,
425}
426
427impl BCEWithLogitsLoss {
428    /// Creates a new BCEWithLogitsLoss with default reduction (Mean).
429    pub fn new() -> Self {
430        Self {
431            reduction: Reduction::Mean,
432        }
433    }
434
435    /// Creates BCEWithLogitsLoss with specified reduction.
436    pub fn with_reduction(reduction: Reduction) -> Self {
437        Self { reduction }
438    }
439
440    /// Computes the loss.
441    pub fn compute(&self, input: &Variable, target: &Variable) -> Variable {
442        let input_data = input.data();
443        let target_data = target.data();
444
445        let input_vec = input_data.to_vec();
446        let target_vec = target_data.to_vec();
447
448        // Numerically stable: max(x, 0) - x*t + log(1 + exp(-|x|))
449        let losses: Vec<f32> = input_vec
450            .iter()
451            .zip(target_vec.iter())
452            .map(|(&x, &t)| {
453                let max_val = x.max(0.0);
454                max_val - x * t + (1.0 + (-x.abs()).exp()).ln()
455            })
456            .collect();
457
458        let loss_tensor = Tensor::from_vec(losses, input_data.shape()).unwrap();
459        let loss_var = Variable::new(loss_tensor, input.requires_grad());
460
461        match self.reduction {
462            Reduction::None => loss_var,
463            Reduction::Mean => loss_var.mean(),
464            Reduction::Sum => loss_var.sum(),
465        }
466    }
467}
468
469impl Default for BCEWithLogitsLoss {
470    fn default() -> Self {
471        Self::new()
472    }
473}
474
475// =============================================================================
476// SmoothL1Loss
477// =============================================================================
478
479/// Smooth L1 loss (Huber loss).
480///
481/// Uses L2 loss when |x| < beta, L1 loss otherwise.
482#[derive(Debug, Clone, Copy)]
483pub struct SmoothL1Loss {
484    reduction: Reduction,
485    beta: f32,
486}
487
488impl SmoothL1Loss {
489    /// Creates a new SmoothL1Loss with default beta (1.0).
490    pub fn new() -> Self {
491        Self {
492            reduction: Reduction::Mean,
493            beta: 1.0,
494        }
495    }
496
497    /// Creates SmoothL1Loss with specified beta.
498    pub fn with_beta(beta: f32) -> Self {
499        Self {
500            reduction: Reduction::Mean,
501            beta,
502        }
503    }
504
505    /// Computes the loss.
506    pub fn compute(&self, input: &Variable, target: &Variable) -> Variable {
507        let diff = input.sub_var(target);
508        let diff_data = diff.data();
509        let diff_vec = diff_data.to_vec();
510
511        let losses: Vec<f32> = diff_vec
512            .iter()
513            .map(|&d| {
514                let abs_d = d.abs();
515                if abs_d < self.beta {
516                    0.5 * d * d / self.beta
517                } else {
518                    abs_d - 0.5 * self.beta
519                }
520            })
521            .collect();
522
523        let loss_tensor = Tensor::from_vec(losses, diff_data.shape()).unwrap();
524        let loss_var = Variable::new(loss_tensor, diff.requires_grad());
525
526        match self.reduction {
527            Reduction::None => loss_var,
528            Reduction::Mean => loss_var.mean(),
529            Reduction::Sum => loss_var.sum(),
530        }
531    }
532}
533
534impl Default for SmoothL1Loss {
535    fn default() -> Self {
536        Self::new()
537    }
538}
539
540// =============================================================================
541// Tests
542// =============================================================================
543
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_mse_loss() {
        // Identical prediction and target must give zero loss.
        let criterion = MSELoss::new();
        let pred = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), false);
        let truth = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), false);
        let loss = criterion.compute(&pred, &truth);
        assert!((loss.data().to_vec()[0] - 0.0).abs() < 1e-6);
    }

    #[test]
    fn test_mse_loss_nonzero() {
        // Every element differs by exactly 1.0 => squared error 1.0, mean 1.0.
        let criterion = MSELoss::new();
        let pred = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), false);
        let truth = Variable::new(Tensor::from_vec(vec![2.0, 3.0, 4.0], &[3]).unwrap(), false);
        let loss = criterion.compute(&pred, &truth);
        assert!((loss.data().to_vec()[0] - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_cross_entropy_loss() {
        // Imperfect logits should always produce a strictly positive loss.
        let criterion = CrossEntropyLoss::new();
        let logits = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 1.0, 2.0, 3.0], &[2, 3]).unwrap(),
            false,
        );
        let classes = Variable::new(Tensor::from_vec(vec![2.0, 0.0], &[2]).unwrap(), false);
        let loss = criterion.compute(&logits, &classes);
        assert!(loss.data().to_vec()[0] > 0.0);
    }

    #[test]
    fn test_bce_loss() {
        // p = 0.5 for both labels: mean loss = -ln(0.5) ≈ 0.693.
        let criterion = BCELoss::new();
        let probs = Variable::new(Tensor::from_vec(vec![0.5, 0.5], &[2]).unwrap(), false);
        let labels = Variable::new(Tensor::from_vec(vec![1.0, 0.0], &[2]).unwrap(), false);
        let loss = criterion.compute(&probs, &labels);
        assert!((loss.data().to_vec()[0] - 0.693).abs() < 0.01);
    }

    #[test]
    fn test_cross_entropy_gradient_flow() {
        use axonml_autograd::backward;

        // Two samples over three classes, logits require gradients.
        let logits = Variable::new(
            Tensor::from_vec(vec![2.0, 1.0, 0.1, 0.5, 2.5, 0.3], &[2, 3]).unwrap(),
            true,
        );
        let classes = Variable::new(Tensor::from_vec(vec![0.0, 1.0], &[2]).unwrap(), false);

        let criterion = CrossEntropyLoss::new();
        let loss = criterion.compute(&logits, &classes);

        let loss_val = loss.data().to_vec()[0];
        assert!(loss_val > 0.0, "Loss should be positive, got {}", loss_val);

        // Seed the backward pass with an upstream gradient of one.
        let seed = Tensor::from_vec(vec![1.0], &loss.shape()).unwrap();
        backward(&loss, &seed);

        let grad = logits.grad().expect("Input should have gradient after backward");
        let grad_vec = grad.to_vec();

        // The gradient must actually depend on the inputs (non-zero norm).
        let grad_norm: f32 = grad_vec.iter().map(|g| g * g).sum();
        assert!(
            grad_norm > 1e-10,
            "Gradient should be non-zero, got norm {}",
            grad_norm
        );

        // Gradient shape mirrors the logits.
        assert_eq!(grad.shape(), &[2, 3]);

        // Target-class entries follow softmax - 1 < 0 ...
        assert!(grad_vec[0] < 0.0, "Gradient for correct class should be negative");
        assert!(grad_vec[4] < 0.0, "Gradient for correct class should be negative");

        // ... while non-target entries follow softmax > 0.
        assert!(grad_vec[1] > 0.0, "Gradient for wrong class should be positive");
        assert!(grad_vec[2] > 0.0, "Gradient for wrong class should be positive");
    }

    #[test]
    fn test_cross_entropy_perfect_prediction() {
        // Logits strongly favoring the target class => loss near zero.
        let criterion = CrossEntropyLoss::new();
        let logits = Variable::new(
            Tensor::from_vec(vec![10.0, -10.0, -10.0], &[1, 3]).unwrap(),
            false,
        );
        let classes = Variable::new(Tensor::from_vec(vec![0.0], &[1]).unwrap(), false);
        let loss = criterion.compute(&logits, &classes);
        assert!(loss.data().to_vec()[0] < 0.001, "Perfect prediction should have near-zero loss");
    }

    #[test]
    fn test_cross_entropy_uniform_prediction() {
        // All-equal logits => uniform softmax => loss = ln(C).
        let criterion = CrossEntropyLoss::new();
        let num_classes = 16;
        let logits = Variable::new(
            Tensor::from_vec(vec![0.0; num_classes], &[1, num_classes]).unwrap(),
            false,
        );
        let classes = Variable::new(Tensor::from_vec(vec![0.0], &[1]).unwrap(), false);
        let loss = criterion.compute(&logits, &classes);
        let expected = (num_classes as f32).ln(); // ln(16) ≈ 2.77
        let actual = loss.data().to_vec()[0];
        assert!(
            (actual - expected).abs() < 0.01,
            "Uniform logits should give ln(C)={}, got {}",
            expected,
            actual,
        );
    }
}