Skip to main content

entrenar/train/loss/
mod.rs

1//! Loss functions for training
2//!
3//! This module provides various loss functions for neural network training:
4//!
5//! - [`MSELoss`] - Mean Squared Error for regression
6//! - [`L1Loss`] - Mean Absolute Error (more robust to outliers)
7//! - [`HuberLoss`] / [`SmoothL1Loss`] - Smooth combination of MSE and L1
8//! - [`CrossEntropyLoss`] - For single-label classification tasks
9//! - [`BCEWithLogitsLoss`] - For multi-label classification (sigmoid per class)
10//! - [`CausalLMLoss`] - For autoregressive language modeling
11//! - [`WeightedLoss`] - Scalar weighting wrapper
12//! - [`SampleWeightedLoss`] - Per-sample weighting for curriculum learning
13
14#[cfg(test)]
15mod accuracy_tests;
16mod bce_with_logits;
17mod causal_lm;
18mod cross_entropy;
19mod mse;
20mod traits;
21mod weighted;
22
23pub use bce_with_logits::BCEWithLogitsLoss;
24pub use causal_lm::CausalLMLoss;
25pub use cross_entropy::CrossEntropyLoss;
26pub use mse::{HuberLoss, L1Loss, MSELoss, SmoothL1Loss};
27pub use traits::LossFn;
28pub use weighted::{SampleWeightedLoss, WeightedLoss};
29
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks that `loss` reports `expected` from its `name()` method.
    ///
    /// `#[track_caller]` makes a failing assertion point at the calling line in
    /// `test_loss_names`, not at this helper.
    #[track_caller]
    fn assert_loss_name<L: LossFn>(loss: L, expected: &str) {
        assert_eq!(loss.name(), expected);
    }

    /// Every re-exported loss type (and each wrapper) exposes a stable display name.
    #[test]
    fn test_loss_names() {
        // Plain (unit-struct or simply-constructed) losses.
        assert_loss_name(MSELoss, "MSE");
        assert_loss_name(CrossEntropyLoss, "CrossEntropy");
        assert_loss_name(BCEWithLogitsLoss, "BCEWithLogits");
        assert_loss_name(HuberLoss::new(1.0), "Huber");
        assert_loss_name(L1Loss, "L1");

        // Wrappers report their own name, not the wrapped loss's name.
        assert_loss_name(WeightedLoss::new(Box::new(MSELoss), 1.0), "Weighted");
        assert_loss_name(SampleWeightedLoss::new(Box::new(MSELoss)), "SampleWeighted");

        // Language-modelling loss, constructed here with the argument 10.
        assert_loss_name(CausalLMLoss::new(10), "CausalLM");
    }
}