entrenar/train/loss/
mod.rs1#[cfg(test)]
15mod accuracy_tests;
16mod bce_with_logits;
17mod causal_lm;
18mod cross_entropy;
19mod mse;
20mod traits;
21mod weighted;
22
23pub use bce_with_logits::BCEWithLogitsLoss;
24pub use causal_lm::CausalLMLoss;
25pub use cross_entropy::CrossEntropyLoss;
26pub use mse::{HuberLoss, L1Loss, MSELoss, SmoothL1Loss};
27pub use traits::LossFn;
28pub use weighted::{SampleWeightedLoss, WeightedLoss};
29
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks that each loss listed below reports the expected name from `name()`.
    ///
    /// NOTE(review): `SmoothL1Loss` is re-exported by this module but is not
    /// covered here; confirm its constructor and expected name before adding it.
    #[test]
    fn test_loss_names() {
        // Pair each loss's reported name with the expected one. Converting to
        // `String` keeps the check independent of whether `name()` returns
        // `&str` or an owned string.
        let cases = [
            (MSELoss.name().to_string(), "MSE"),
            (CrossEntropyLoss.name().to_string(), "CrossEntropy"),
            (BCEWithLogitsLoss.name().to_string(), "BCEWithLogits"),
            (HuberLoss::new(1.0).name().to_string(), "Huber"),
            (L1Loss.name().to_string(), "L1"),
            (WeightedLoss::new(Box::new(MSELoss), 1.0).name().to_string(), "Weighted"),
            (SampleWeightedLoss::new(Box::new(MSELoss)).name().to_string(), "SampleWeighted"),
            (CausalLMLoss::new(10).name().to_string(), "CausalLM"),
        ];

        for (actual, expected) in &cases {
            assert_eq!(actual.as_str(), *expected);
        }
    }
}