axonml_nn/loss.rs

//! Loss Functions - Training Objectives
//!
//! Provides loss functions for training neural networks.
//!
//! @version 0.1.0
//! @author AutomataNexus Development Team

use axonml_autograd::Variable;
use axonml_tensor::Tensor;

use crate::module::Module;

// =============================================================================
// Reduction Enum
// =============================================================================

/// Specifies how to reduce the loss over elements.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub enum Reduction {
    /// No reduction - return loss per element.
    None,
    /// Mean of all losses.
    #[default]
    Mean,
    /// Sum of all losses.
    Sum,
}

// =============================================================================
// MSELoss
// =============================================================================

/// Mean Squared Error loss.
///
/// loss = mean((input - target)^2)
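///
/// # Example
///
/// A minimal usage sketch (kept out of doctests), mirroring the unit tests at
/// the bottom of this file; it assumes the `Variable` and `Tensor`
/// constructors imported above.
///
/// ```ignore
/// let loss_fn = MSELoss::new();
/// let input = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), false);
/// let target = Variable::new(Tensor::from_vec(vec![2.0, 3.0, 4.0], &[3]).unwrap(), false);
/// // Each squared difference is 1.0, so the Mean reduction yields 1.0.
/// let loss = loss_fn.compute(&input, &target);
/// ```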
#[derive(Debug, Clone, Copy)]
pub struct MSELoss {
    reduction: Reduction,
}

impl MSELoss {
    /// Creates a new MSELoss with default reduction (Mean).
    pub fn new() -> Self {
        Self {
            reduction: Reduction::Mean,
        }
    }

    /// Creates MSELoss with specified reduction.
    pub fn with_reduction(reduction: Reduction) -> Self {
        Self { reduction }
    }

    /// Computes the loss.
    pub fn compute(&self, input: &Variable, target: &Variable) -> Variable {
        let diff = input.sub_var(target);
        let squared = diff.pow(2.0);

        match self.reduction {
            Reduction::None => squared,
            Reduction::Mean => squared.mean(),
            Reduction::Sum => squared.sum(),
        }
    }
}

impl Default for MSELoss {
    fn default() -> Self {
        Self::new()
    }
}

impl Module for MSELoss {
    fn forward(&self, input: &Variable) -> Variable {
        // The Module interface takes a single input, so the target cannot be
        // passed here; use compute(input, target) instead.
        input.clone()
    }

    fn name(&self) -> &'static str {
        "MSELoss"
    }
}

// =============================================================================
// L1Loss
// =============================================================================

/// Mean Absolute Error loss.
///
/// loss = mean(|input - target|)
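///
/// # Example
///
/// A short sketch (not run as a doctest) showing the `Sum` reduction; values
/// follow the formula above.
///
/// ```ignore
/// let loss_fn = L1Loss::with_reduction(Reduction::Sum);
/// let input = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), false);
/// let target = Variable::new(Tensor::from_vec(vec![2.0, 3.0, 4.0], &[3]).unwrap(), false);
/// // Each |difference| is 1.0; the Sum reduction yields 3.0.
/// let loss = loss_fn.compute(&input, &target);
/// ```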
#[derive(Debug, Clone, Copy)]
pub struct L1Loss {
    reduction: Reduction,
}

impl L1Loss {
    /// Creates a new L1Loss with default reduction (Mean).
    pub fn new() -> Self {
        Self {
            reduction: Reduction::Mean,
        }
    }

    /// Creates L1Loss with specified reduction.
    pub fn with_reduction(reduction: Reduction) -> Self {
        Self { reduction }
    }

    /// Computes the loss.
    pub fn compute(&self, input: &Variable, target: &Variable) -> Variable {
        let diff = input.sub_var(target);
        let diff_data = diff.data();
        let abs_data: Vec<f32> = diff_data.to_vec().iter().map(|x| x.abs()).collect();
        let abs_tensor = Tensor::from_vec(abs_data, diff_data.shape()).unwrap();
        let abs_var = Variable::new(abs_tensor, diff.requires_grad());

        match self.reduction {
            Reduction::None => abs_var,
            Reduction::Mean => abs_var.mean(),
            Reduction::Sum => abs_var.sum(),
        }
    }
}

impl Default for L1Loss {
    fn default() -> Self {
        Self::new()
    }
}

// =============================================================================
// CrossEntropyLoss
// =============================================================================

/// Cross entropy loss with log softmax.
///
/// This combines LogSoftmax and NLLLoss in a single class.
///
/// # Shape
/// - Input: (N, C) where C = number of classes
/// - Target: (N,) with class indices
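///
/// # Example
///
/// A usage sketch (not run as a doctest): logits of shape (N, C) and class
/// indices stored as f32, as in the unit tests below.
///
/// ```ignore
/// let loss_fn = CrossEntropyLoss::new();
/// // Two samples, three classes.
/// let logits = Variable::new(
///     Tensor::from_vec(vec![1.0, 2.0, 3.0, 1.0, 2.0, 3.0], &[2, 3]).unwrap(),
///     false,
/// );
/// // Class indices 2 and 0, passed as f32 values.
/// let targets = Variable::new(Tensor::from_vec(vec![2.0, 0.0], &[2]).unwrap(), false);
/// // loss_b = logsumexp(logits_b) - logits_b[target_b], averaged over the batch.
/// let loss = loss_fn.compute(&logits, &targets);
/// ```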
#[derive(Debug, Clone, Copy)]
pub struct CrossEntropyLoss {
    reduction: Reduction,
}

impl CrossEntropyLoss {
    /// Creates a new CrossEntropyLoss with default reduction (Mean).
    pub fn new() -> Self {
        Self {
            reduction: Reduction::Mean,
        }
    }

    /// Creates CrossEntropyLoss with specified reduction.
    pub fn with_reduction(reduction: Reduction) -> Self {
        Self { reduction }
    }

    /// Computes the loss.
    ///
    /// # Arguments
    /// * `input` - Logits of shape (N, C)
    /// * `target` - Class indices of shape (N,) as f32 (will be cast to usize)
    pub fn compute(&self, input: &Variable, target: &Variable) -> Variable {
        let input_data = input.data();
        let target_data = target.data();
        let shape = input_data.shape().to_vec();
        let batch_size = shape[0];
        let num_classes = shape[1];

        let input_vec = input_data.to_vec();
        let target_vec = target_data.to_vec();

        let mut losses = vec![0.0f32; batch_size];

        for b in 0..batch_size {
            // Log-sum-exp over the logits, shifted by the max for numerical stability
            let offset = b * num_classes;
            let max_val = (0..num_classes)
                .map(|c| input_vec[offset + c])
                .fold(f32::NEG_INFINITY, f32::max);

            let mut log_sum_exp = 0.0f32;
            for c in 0..num_classes {
                log_sum_exp += (input_vec[offset + c] - max_val).exp();
            }
            log_sum_exp = max_val + log_sum_exp.ln();

            // NLL of the target class: logsumexp(x) - x[target] = -log_softmax(x)[target]
            let target_class = target_vec[b] as usize;
            losses[b] = log_sum_exp - input_vec[offset + target_class];
        }

        let loss_tensor = Tensor::from_vec(losses, &[batch_size]).unwrap();
        let loss_var = Variable::new(loss_tensor, input.requires_grad());

        match self.reduction {
            Reduction::None => loss_var,
            Reduction::Mean => loss_var.mean(),
            Reduction::Sum => loss_var.sum(),
        }
    }
}

impl Default for CrossEntropyLoss {
    fn default() -> Self {
        Self::new()
    }
}

// =============================================================================
// NLLLoss
// =============================================================================

/// Negative Log Likelihood loss.
///
/// Expects input to be log-probabilities.
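///
/// # Example
///
/// A sketch (not run as a doctest) with uniform log-probabilities over three
/// classes, so the loss is ln(3).
///
/// ```ignore
/// let loss_fn = NLLLoss::new();
/// let lp = (1.0f32 / 3.0).ln();
/// let log_probs = Variable::new(Tensor::from_vec(vec![lp, lp, lp], &[1, 3]).unwrap(), false);
/// let target = Variable::new(Tensor::from_vec(vec![1.0], &[1]).unwrap(), false);
/// // loss = -log_probs[0][1] = ln(3) ≈ 1.0986
/// let loss = loss_fn.compute(&log_probs, &target);
/// ```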
#[derive(Debug, Clone, Copy)]
pub struct NLLLoss {
    reduction: Reduction,
}

impl NLLLoss {
    /// Creates a new NLLLoss with default reduction (Mean).
    pub fn new() -> Self {
        Self {
            reduction: Reduction::Mean,
        }
    }

    /// Creates NLLLoss with specified reduction.
    pub fn with_reduction(reduction: Reduction) -> Self {
        Self { reduction }
    }

    /// Computes the loss.
    pub fn compute(&self, input: &Variable, target: &Variable) -> Variable {
        let input_data = input.data();
        let target_data = target.data();
        let shape = input_data.shape().to_vec();
        let batch_size = shape[0];
        let num_classes = shape[1];

        let input_vec = input_data.to_vec();
        let target_vec = target_data.to_vec();

        let mut losses = vec![0.0f32; batch_size];

        for b in 0..batch_size {
            let target_class = target_vec[b] as usize;
            losses[b] = -input_vec[b * num_classes + target_class];
        }

        let loss_tensor = Tensor::from_vec(losses, &[batch_size]).unwrap();
        let loss_var = Variable::new(loss_tensor, input.requires_grad());

        match self.reduction {
            Reduction::None => loss_var,
            Reduction::Mean => loss_var.mean(),
            Reduction::Sum => loss_var.sum(),
        }
    }
}

impl Default for NLLLoss {
    fn default() -> Self {
        Self::new()
    }
}

// =============================================================================
// BCELoss
// =============================================================================

/// Binary Cross Entropy loss.
///
/// Expects input to be probabilities in [0, 1].
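///
/// # Example
///
/// A sketch (not run as a doctest); predicted probabilities of 0.5 give a
/// loss of -ln(0.5) ≈ 0.693 per element.
///
/// ```ignore
/// let loss_fn = BCELoss::new();
/// let probs = Variable::new(Tensor::from_vec(vec![0.5, 0.5], &[2]).unwrap(), false);
/// let targets = Variable::new(Tensor::from_vec(vec![1.0, 0.0], &[2]).unwrap(), false);
/// let loss = loss_fn.compute(&probs, &targets); // mean loss ≈ 0.693
/// ```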
#[derive(Debug, Clone, Copy)]
pub struct BCELoss {
    reduction: Reduction,
}

impl BCELoss {
    /// Creates a new BCELoss with default reduction (Mean).
    pub fn new() -> Self {
        Self {
            reduction: Reduction::Mean,
        }
    }

    /// Creates BCELoss with specified reduction.
    pub fn with_reduction(reduction: Reduction) -> Self {
        Self { reduction }
    }

    /// Computes the loss.
    pub fn compute(&self, input: &Variable, target: &Variable) -> Variable {
        let eps = 1e-7f32;
        let input_data = input.data();
        let target_data = target.data();

        let input_vec = input_data.to_vec();
        let target_vec = target_data.to_vec();

        let losses: Vec<f32> = input_vec
            .iter()
            .zip(target_vec.iter())
            .map(|(&p, &t)| {
                // Clamp probabilities away from 0 and 1 to avoid ln(0)
                let p_clamped = p.clamp(eps, 1.0 - eps);
                -(t * p_clamped.ln() + (1.0 - t) * (1.0 - p_clamped).ln())
            })
            .collect();

        let loss_tensor = Tensor::from_vec(losses, input_data.shape()).unwrap();
        let loss_var = Variable::new(loss_tensor, input.requires_grad());

        match self.reduction {
            Reduction::None => loss_var,
            Reduction::Mean => loss_var.mean(),
            Reduction::Sum => loss_var.sum(),
        }
    }
}

impl Default for BCELoss {
    fn default() -> Self {
        Self::new()
    }
}

// =============================================================================
// BCEWithLogitsLoss
// =============================================================================

/// Binary Cross Entropy with Logits.
///
/// Combines sigmoid and BCE in a numerically stable way.
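///
/// # Example
///
/// A sketch (not run as a doctest); a logit of 0.0 corresponds to a
/// probability of 0.5, so each element's loss is ln(2) ≈ 0.693.
///
/// ```ignore
/// let loss_fn = BCEWithLogitsLoss::new();
/// let logits = Variable::new(Tensor::from_vec(vec![0.0, 0.0], &[2]).unwrap(), false);
/// let targets = Variable::new(Tensor::from_vec(vec![1.0, 0.0], &[2]).unwrap(), false);
/// let loss = loss_fn.compute(&logits, &targets); // mean loss ≈ 0.693
/// ```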
#[derive(Debug, Clone, Copy)]
pub struct BCEWithLogitsLoss {
    reduction: Reduction,
}

impl BCEWithLogitsLoss {
    /// Creates a new BCEWithLogitsLoss with default reduction (Mean).
    pub fn new() -> Self {
        Self {
            reduction: Reduction::Mean,
        }
    }

    /// Creates BCEWithLogitsLoss with specified reduction.
    pub fn with_reduction(reduction: Reduction) -> Self {
        Self { reduction }
    }

    /// Computes the loss.
    pub fn compute(&self, input: &Variable, target: &Variable) -> Variable {
        let input_data = input.data();
        let target_data = target.data();

        let input_vec = input_data.to_vec();
        let target_vec = target_data.to_vec();

        // Numerically stable: max(x, 0) - x*t + log(1 + exp(-|x|))
        let losses: Vec<f32> = input_vec
            .iter()
            .zip(target_vec.iter())
            .map(|(&x, &t)| {
                let max_val = x.max(0.0);
                max_val - x * t + (1.0 + (-x.abs()).exp()).ln()
            })
            .collect();

        let loss_tensor = Tensor::from_vec(losses, input_data.shape()).unwrap();
        let loss_var = Variable::new(loss_tensor, input.requires_grad());

        match self.reduction {
            Reduction::None => loss_var,
            Reduction::Mean => loss_var.mean(),
            Reduction::Sum => loss_var.sum(),
        }
    }
}

impl Default for BCEWithLogitsLoss {
    fn default() -> Self {
        Self::new()
    }
}

// =============================================================================
// SmoothL1Loss
// =============================================================================

/// Smooth L1 loss (Huber loss).
///
/// Uses L2 loss when |x| < beta, L1 loss otherwise.
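///
/// # Example
///
/// A sketch (not run as a doctest) with the default beta of 1.0: a residual
/// of 0.5 falls in the quadratic region, a residual of 2.0 in the linear one.
///
/// ```ignore
/// let loss_fn = SmoothL1Loss::with_beta(1.0);
/// let input = Variable::new(Tensor::from_vec(vec![0.5, 2.0], &[2]).unwrap(), false);
/// let target = Variable::new(Tensor::from_vec(vec![0.0, 0.0], &[2]).unwrap(), false);
/// // Per-element losses: 0.5 * 0.5^2 / 1.0 = 0.125 and 2.0 - 0.5 = 1.5; mean = 0.8125.
/// let loss = loss_fn.compute(&input, &target);
/// ```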
#[derive(Debug, Clone, Copy)]
pub struct SmoothL1Loss {
    reduction: Reduction,
    beta: f32,
}

impl SmoothL1Loss {
    /// Creates a new SmoothL1Loss with default beta (1.0).
    pub fn new() -> Self {
        Self {
            reduction: Reduction::Mean,
            beta: 1.0,
        }
    }

    /// Creates SmoothL1Loss with specified beta.
    pub fn with_beta(beta: f32) -> Self {
        Self {
            reduction: Reduction::Mean,
            beta,
        }
    }

    /// Computes the loss.
    pub fn compute(&self, input: &Variable, target: &Variable) -> Variable {
        let diff = input.sub_var(target);
        let diff_data = diff.data();
        let diff_vec = diff_data.to_vec();

        let losses: Vec<f32> = diff_vec
            .iter()
            .map(|&d| {
                let abs_d = d.abs();
                if abs_d < self.beta {
                    0.5 * d * d / self.beta
                } else {
                    abs_d - 0.5 * self.beta
                }
            })
            .collect();

        let loss_tensor = Tensor::from_vec(losses, diff_data.shape()).unwrap();
        let loss_var = Variable::new(loss_tensor, diff.requires_grad());

        match self.reduction {
            Reduction::None => loss_var,
            Reduction::Mean => loss_var.mean(),
            Reduction::Sum => loss_var.sum(),
        }
    }
}

impl Default for SmoothL1Loss {
    fn default() -> Self {
        Self::new()
    }
}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_mse_loss() {
        let loss_fn = MSELoss::new();
        let input = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), false);
        let target = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), false);
        let loss = loss_fn.compute(&input, &target);
        assert!((loss.data().to_vec()[0] - 0.0).abs() < 1e-6);
    }

    #[test]
    fn test_mse_loss_nonzero() {
        let loss_fn = MSELoss::new();
        let input = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), false);
        let target = Variable::new(Tensor::from_vec(vec![2.0, 3.0, 4.0], &[3]).unwrap(), false);
        let loss = loss_fn.compute(&input, &target);
        // Each diff is 1.0, squared is 1.0, mean is 1.0
        assert!((loss.data().to_vec()[0] - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_cross_entropy_loss() {
        let loss_fn = CrossEntropyLoss::new();
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 1.0, 2.0, 3.0], &[2, 3]).unwrap(),
            false,
        );
        let target = Variable::new(Tensor::from_vec(vec![2.0, 0.0], &[2]).unwrap(), false);
        let loss = loss_fn.compute(&input, &target);
        assert!(loss.data().to_vec()[0] > 0.0);
    }

    #[test]
    fn test_bce_loss() {
        let loss_fn = BCELoss::new();
        let input = Variable::new(Tensor::from_vec(vec![0.5, 0.5], &[2]).unwrap(), false);
        let target = Variable::new(Tensor::from_vec(vec![1.0, 0.0], &[2]).unwrap(), false);
        let loss = loss_fn.compute(&input, &target);
        // Each element's loss is -ln(0.5), so the mean is -ln(0.5) ≈ 0.693
        assert!((loss.data().to_vec()[0] - 0.693).abs() < 0.01);
    }
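
    // Additional sanity checks for the losses without tests above. Expected
    // values are hand-computed from the formulas in each compute() method,
    // using the same Variable/Tensor construction as the existing tests.

    #[test]
    fn test_l1_loss() {
        let loss_fn = L1Loss::new();
        let input = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), false);
        let target = Variable::new(Tensor::from_vec(vec![2.0, 3.0, 4.0], &[3]).unwrap(), false);
        let loss = loss_fn.compute(&input, &target);
        // Each |diff| is 1.0, so the mean is 1.0
        assert!((loss.data().to_vec()[0] - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_nll_loss() {
        let loss_fn = NLLLoss::new();
        // Uniform log-probabilities over 3 classes: ln(1/3) each
        let lp = (1.0f32 / 3.0).ln();
        let input = Variable::new(Tensor::from_vec(vec![lp, lp, lp], &[1, 3]).unwrap(), false);
        let target = Variable::new(Tensor::from_vec(vec![1.0], &[1]).unwrap(), false);
        let loss = loss_fn.compute(&input, &target);
        // loss = -ln(1/3) = ln(3)
        assert!((loss.data().to_vec()[0] - 3.0f32.ln()).abs() < 1e-5);
    }

    #[test]
    fn test_bce_with_logits_loss() {
        let loss_fn = BCEWithLogitsLoss::new();
        // A logit of 0.0 is a probability of 0.5, so each loss is ln(2)
        let input = Variable::new(Tensor::from_vec(vec![0.0, 0.0], &[2]).unwrap(), false);
        let target = Variable::new(Tensor::from_vec(vec![1.0, 0.0], &[2]).unwrap(), false);
        let loss = loss_fn.compute(&input, &target);
        assert!((loss.data().to_vec()[0] - 2.0f32.ln()).abs() < 1e-5);
    }

    #[test]
    fn test_smooth_l1_loss() {
        let loss_fn = SmoothL1Loss::new();
        // |diff| = 0.5 (< beta): 0.5 * 0.25 / 1.0 = 0.125
        // |diff| = 2.0 (>= beta): 2.0 - 0.5 = 1.5
        let input = Variable::new(Tensor::from_vec(vec![0.5, 2.0], &[2]).unwrap(), false);
        let target = Variable::new(Tensor::from_vec(vec![0.0, 0.0], &[2]).unwrap(), false);
        let loss = loss_fn.compute(&input, &target);
        // Mean of [0.125, 1.5] = 0.8125
        assert!((loss.data().to_vec()[0] - 0.8125).abs() < 1e-6);
    }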
}