// irithyll_core/loss/quantile.rs

use super::{Loss, LossType};
/// Pinball (quantile) loss for boosting toward a fixed quantile level.
///
/// `tau = 0.5` recovers a median / MAE-like loss; `tau = 0.9` targets the
/// 90th percentile.
#[derive(Debug, Clone, Copy)]
pub struct QuantileLoss {
    /// Quantile level, strictly inside (0, 1); enforced by `QuantileLoss::new`.
    pub tau: f64,
}
26
27impl QuantileLoss {
28 pub fn new(tau: f64) -> Self {
34 assert!(tau > 0.0 && tau < 1.0, "tau must be in (0, 1), got {tau}");
35 Self { tau }
36 }
37}
38
39impl Loss for QuantileLoss {
40 #[inline]
41 fn n_outputs(&self) -> usize {
42 1
43 }
44
45 #[inline]
46 fn gradient(&self, target: f64, prediction: f64) -> f64 {
47 if prediction >= target {
51 1.0 - self.tau
52 } else {
53 -self.tau
54 }
55 }
56
57 #[inline]
58 fn hessian(&self, _target: f64, _prediction: f64) -> f64 {
59 1.0
62 }
63
64 #[inline]
65 fn loss(&self, target: f64, prediction: f64) -> f64 {
66 let r = target - prediction;
67 if r >= 0.0 {
68 self.tau * r
69 } else {
70 (self.tau - 1.0) * r
71 }
72 }
73
74 #[inline]
75 fn predict_transform(&self, raw: f64) -> f64 {
76 raw
77 }
78
79 fn initial_prediction(&self, targets: &[f64]) -> f64 {
80 if targets.is_empty() {
81 return 0.0;
82 }
83 #[cfg(feature = "alloc")]
84 {
85 let mut sorted: alloc::vec::Vec<f64> = alloc::vec::Vec::from(targets);
87 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(core::cmp::Ordering::Equal));
88 let idx = ((self.tau * sorted.len() as f64) as usize).min(sorted.len() - 1);
89 sorted[idx]
90 }
91 #[cfg(not(feature = "alloc"))]
92 {
93 let sum: f64 = targets.iter().sum();
95 sum / targets.len() as f64
96 }
97 }
98
99 fn loss_type(&self) -> Option<LossType> {
100 Some(LossType::Quantile { tau: self.tau })
101 }
102}
103
#[cfg(test)]
mod tests {
    use super::*;

    const EPS: f64 = 1e-12;

    /// True when `actual` is within `EPS` of `expected`.
    fn approx(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < EPS
    }

    #[test]
    fn test_n_outputs() {
        assert_eq!(QuantileLoss::new(0.5).n_outputs(), 1);
    }

    #[test]
    fn test_gradient_over_predict() {
        // prediction above target -> gradient is 1 - tau
        let q = QuantileLoss::new(0.9);
        assert!(approx(q.gradient(1.0, 3.0), 0.1));
        assert!(approx(q.gradient(0.0, 100.0), 0.1));
    }

    #[test]
    fn test_gradient_under_predict() {
        // prediction below target -> gradient is -tau
        let q = QuantileLoss::new(0.9);
        assert!(approx(q.gradient(3.0, 1.0), -0.9));
        assert!(approx(q.gradient(100.0, 0.0), -0.9));
    }

    #[test]
    fn test_gradient_at_exact() {
        // At equality the `>=` branch applies: 1 - tau = 0.5.
        let q = QuantileLoss::new(0.5);
        assert!(approx(q.gradient(5.0, 5.0), 0.5));
    }

    #[test]
    fn test_hessian_is_one() {
        let q = QuantileLoss::new(0.9);
        for (t, p) in [(0.0, 0.0), (100.0, -50.0), (-7.0, 42.0)] {
            assert!(approx(q.hessian(t, p), 1.0));
        }
    }

    #[test]
    fn test_loss_pinball() {
        let q = QuantileLoss::new(0.9);
        assert!(approx(q.loss(5.0, 3.0), 0.9 * 2.0)); // under-predict: tau * r
        assert!(approx(q.loss(3.0, 5.0), 0.1 * 2.0)); // over-predict: (1 - tau) * |r|
        assert!(approx(q.loss(4.0, 4.0), 0.0)); // exact hit costs nothing
    }

    #[test]
    fn test_median_loss_is_half_mae() {
        let q = QuantileLoss::new(0.5);
        assert!(approx(q.loss(5.0, 3.0), 1.0));
        assert!(approx(q.loss(3.0, 5.0), 1.0));
    }

    #[test]
    fn test_predict_transform_is_identity() {
        let q = QuantileLoss::new(0.5);
        assert!(approx(q.predict_transform(42.0), 42.0));
    }

    #[cfg(feature = "alloc")]
    #[test]
    fn test_initial_prediction_is_quantile() {
        let targets = [1.0, 2.0, 3.0, 4.0, 5.0];
        assert!(approx(QuantileLoss::new(0.5).initial_prediction(&targets), 3.0));
        assert!(approx(QuantileLoss::new(0.9).initial_prediction(&targets), 5.0));
    }

    #[test]
    fn test_initial_prediction_empty() {
        // An empty slice falls back to 0.0.
        assert!(approx(QuantileLoss::new(0.5).initial_prediction(&[]), 0.0));
    }

    #[test]
    fn test_loss_type_returns_some() {
        match QuantileLoss::new(0.75).loss_type() {
            Some(LossType::Quantile { tau }) => assert!(approx(tau, 0.75)),
            other => panic!("expected Quantile, got {other:?}"),
        }
    }

    #[test]
    fn test_gradient_is_subderivative_of_loss() {
        // Central finite differences away from the kink must agree with the
        // analytical subgradient on both sides of the target.
        let q = QuantileLoss::new(0.75);
        let target = 2.5;
        let h = 1e-6;

        for pred in [4.0, 1.0] {
            let numerical = (q.loss(target, pred + h) - q.loss(target, pred - h)) / (2.0 * h);
            let analytical = q.gradient(target, pred);
            assert!(
                (numerical - analytical).abs() < 1e-4,
                "pred={pred}: numerical={numerical}, analytical={analytical}"
            );
        }
    }

    #[test]
    #[should_panic(expected = "tau must be in (0, 1)")]
    fn test_invalid_tau_zero() {
        QuantileLoss::new(0.0);
    }

    #[test]
    #[should_panic(expected = "tau must be in (0, 1)")]
    fn test_invalid_tau_one() {
        QuantileLoss::new(1.0);
    }
}