irithyll_core/loss/
squared.rs1use super::Loss;
/// Squared-error loss `L(target, prediction) = ½ (prediction − target)²` for
/// single-output regression.
///
/// This is a stateless unit struct: every instance is interchangeable, so it is
/// `Copy`, has a trivial `Default`, and compares equal to any other instance.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct SquaredLoss;
19impl Loss for SquaredLoss {
20 #[inline]
21 fn n_outputs(&self) -> usize {
22 1
23 }
24
25 #[inline]
26 fn gradient(&self, target: f64, prediction: f64) -> f64 {
27 prediction - target
28 }
29
30 #[inline]
31 fn hessian(&self, _target: f64, _prediction: f64) -> f64 {
32 1.0
33 }
34
35 #[inline]
36 fn loss(&self, target: f64, prediction: f64) -> f64 {
37 let r = prediction - target;
38 0.5 * r * r
39 }
40
41 #[inline]
42 fn predict_transform(&self, raw: f64) -> f64 {
43 raw
44 }
45
46 fn initial_prediction(&self, targets: &[f64]) -> f64 {
47 if targets.is_empty() {
48 return 0.0;
49 }
50 let sum: f64 = targets.iter().sum();
51 sum / targets.len() as f64
52 }
53
54 fn loss_type(&self) -> Option<super::LossType> {
55 Some(super::LossType::Squared)
56 }
57}
58
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for comparisons that should be exact up to rounding.
    const EPS: f64 = 1e-12;

    /// Returns `true` when `actual` lies strictly within [`EPS`] of `expected`.
    fn close(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < EPS
    }

    #[test]
    fn test_n_outputs() {
        let outputs = SquaredLoss.n_outputs();
        assert_eq!(outputs, 1);
    }

    #[test]
    fn test_gradient_at_known_points() {
        // (target, prediction, expected gradient)
        let cases = [(3.0, 3.0, 0.0), (1.0, 4.0, 3.0), (5.0, 2.0, -3.0)];
        for &(target, prediction, expected) in &cases {
            assert!(close(SquaredLoss.gradient(target, prediction), expected));
        }
    }

    #[test]
    fn test_hessian_is_constant() {
        // (target, prediction) pairs spanning signs and magnitudes.
        let cases = [(0.0, 0.0), (100.0, -50.0), (-7.0, 42.0)];
        for &(target, prediction) in &cases {
            assert!(close(SquaredLoss.hessian(target, prediction), 1.0));
        }
    }

    #[test]
    fn test_loss_value() {
        // (target, prediction, expected loss)
        let cases = [(1.0, 3.0, 2.0), (5.0, 5.0, 0.0), (0.0, 1.0, 0.5)];
        for &(target, prediction, expected) in &cases {
            assert!(close(SquaredLoss.loss(target, prediction), expected));
        }
    }

    #[test]
    fn test_predict_transform_is_identity() {
        for &raw in &[42.0, -3.25] {
            assert!(close(SquaredLoss.predict_transform(raw), raw));
        }
    }

    #[test]
    fn test_initial_prediction_is_mean() {
        let several = [1.0, 2.0, 3.0, 4.0, 5.0];
        assert!(close(SquaredLoss.initial_prediction(&several), 3.0));

        let one = [7.0];
        assert!(close(SquaredLoss.initial_prediction(&one), 7.0));
    }

    #[test]
    fn test_initial_prediction_empty() {
        assert!(close(SquaredLoss.initial_prediction(&[]), 0.0));
    }

    #[test]
    fn test_gradient_is_derivative_of_loss() {
        let (target, pred, step) = (2.5, 4.0, 1e-6);
        // Central finite difference of the loss around `pred`.
        let forward = SquaredLoss.loss(target, pred + step);
        let backward = SquaredLoss.loss(target, pred - step);
        let numerical = (forward - backward) / (2.0 * step);
        let analytical = SquaredLoss.gradient(target, pred);
        assert!((numerical - analytical).abs() < 1e-5);
    }
}