Skip to main content

irithyll_core/loss/
mod.rs

1//! Loss functions for gradient boosting.
2//!
3//! Each loss provides gradient and hessian computations used by the boosting
4//! loop to compute pseudo-residuals for tree fitting.
5
6pub mod expectile;
7pub mod huber;
8pub mod logistic;
9pub mod quantile;
10pub mod softmax;
11pub mod squared;
12
13// ---------------------------------------------------------------------------
14// LossType -- serialization tag for built-in loss functions
15// ---------------------------------------------------------------------------
16
17/// Tag identifying a loss function for serialization and reconstruction.
18///
19/// When saving a model, the loss type is captured so the correct loss function
20/// can be reconstructed on load. Built-in losses implement [`Loss::loss_type`]
21/// to return `Some(LossType::...)` automatically.
22///
23/// Custom losses return `None` from `loss_type()` and must be handled manually
24/// during serialization.
/// Serialization tag identifying one of the built-in loss functions.
///
/// Saving a model records this tag so that the matching loss can be
/// reconstructed on load. Built-in losses report their tag through
/// [`Loss::loss_type`] (returning `Some(LossType::...)`); custom losses
/// return `None` and must be handled manually during serialization.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum LossType {
    /// Squared-error loss for regression.
    Squared,
    /// Binary cross-entropy (logistic) loss.
    Logistic,
    /// Huber loss: quadratic below `delta`, linear above it.
    Huber {
        /// Threshold at which loss transitions from quadratic to linear.
        delta: f64,
    },
    /// Multi-class cross-entropy (softmax) loss.
    Softmax {
        /// Number of classes.
        n_classes: usize,
    },
    /// Expectile loss with asymmetry parameter.
    Expectile {
        /// Asymmetry parameter tau in (0, 1).
        tau: f64,
    },
    /// Pinball (quantile) loss targeting a given quantile.
    Quantile {
        /// Target quantile tau in (0, 1).
        tau: f64,
    },
}
53
#[cfg(feature = "alloc")]
impl LossType {
    /// Rebuild the boxed loss function described by this serialized tag.
    pub fn into_loss(self) -> alloc::boxed::Box<dyn Loss> {
        use alloc::boxed::Box;

        // One arm per built-in loss; struct-variant fields carry the
        // parameters needed to reconstruct the concrete loss value.
        match self {
            Self::Squared => Box::new(squared::SquaredLoss),
            Self::Logistic => Box::new(logistic::LogisticLoss),
            Self::Huber { delta } => Box::new(huber::HuberLoss { delta }),
            Self::Softmax { n_classes } => Box::new(softmax::SoftmaxLoss { n_classes }),
            Self::Expectile { tau } => Box::new(expectile::ExpectileLoss::new(tau)),
            Self::Quantile { tau } => Box::new(quantile::QuantileLoss::new(tau)),
        }
    }
}
72
73// ---------------------------------------------------------------------------
74// Loss trait
75// ---------------------------------------------------------------------------
76
77/// A differentiable loss function for gradient boosting.
78///
79/// Implementations must provide first and second derivatives (gradient/hessian)
80/// with respect to the prediction, which drive the boosting updates.
/// A differentiable loss function for gradient boosting.
///
/// Implementations must provide first and second derivatives (gradient/hessian)
/// with respect to the prediction, which drive the boosting updates.
///
/// The `Send + Sync + 'static` bounds let implementations be shared across
/// threads and stored as `Box<dyn Loss>` trait objects.
pub trait Loss: Send + Sync + 'static {
    /// Number of output dimensions (1 for regression/binary, C for multiclass).
    fn n_outputs(&self) -> usize;

    /// First derivative of the loss with respect to prediction: dL/df.
    ///
    /// This is the per-sample pseudo-residual used by the boosting loop.
    fn gradient(&self, target: f64, prediction: f64) -> f64;

    /// Second derivative of the loss with respect to prediction: d^2L/df^2.
    fn hessian(&self, target: f64, prediction: f64) -> f64;

    /// Raw loss value L(target, prediction).
    fn loss(&self, target: f64, prediction: f64) -> f64;

    /// Transform raw model output to final prediction (identity for regression, sigmoid for binary).
    fn predict_transform(&self, raw: f64) -> f64;

    /// Initial constant prediction (before any trees). Typically the optimal constant
    /// that minimizes sum of losses over the given targets.
    ///
    /// Defaults to 0.0; implementations may override with a data-dependent
    /// value (e.g. the mean of the targets).
    fn initial_prediction(&self, _targets: &[f64]) -> f64 {
        0.0
    }

    /// Return the serialization tag for this loss function.
    ///
    /// Built-in losses return `Some(LossType::...)`. Custom losses default to
    /// `None`, which means the model cannot be auto-serialized.
    fn loss_type(&self) -> Option<LossType> {
        None
    }
}
111
112// ---------------------------------------------------------------------------
113// Box<dyn Loss> blanket impl -- enables DynSGBT = SGBT<Box<dyn Loss>>
114// ---------------------------------------------------------------------------
115
/// Blanket delegation: a `Box<dyn Loss>` is itself a [`Loss`], forwarding
/// every method to the boxed trait object. This enables generic code such as
/// `SGBT<Box<dyn Loss>>` to accept dynamically chosen losses.
#[cfg(feature = "alloc")]
impl Loss for alloc::boxed::Box<dyn Loss> {
    #[inline]
    fn n_outputs(&self) -> usize {
        (**self).n_outputs()
    }

    #[inline]
    fn gradient(&self, target: f64, prediction: f64) -> f64 {
        (**self).gradient(target, prediction)
    }

    #[inline]
    fn hessian(&self, target: f64, prediction: f64) -> f64 {
        (**self).hessian(target, prediction)
    }

    #[inline]
    fn loss(&self, target: f64, prediction: f64) -> f64 {
        (**self).loss(target, prediction)
    }

    #[inline]
    fn predict_transform(&self, raw: f64) -> f64 {
        (**self).predict_transform(raw)
    }

    // `#[inline]` added for consistency with the other delegating methods:
    // every forwarder here is a trivial single-call shim.
    #[inline]
    fn initial_prediction(&self, targets: &[f64]) -> f64 {
        (**self).initial_prediction(targets)
    }

    #[inline]
    fn loss_type(&self) -> Option<LossType> {
        (**self).loss_type()
    }
}
151
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_loss_type_variants() {
        // Identically constructed variants must compare equal.
        assert_eq!(LossType::Squared, LossType::Squared);
        assert_eq!(LossType::Huber { delta: 1.5 }, LossType::Huber { delta: 1.5 });
        assert_eq!(LossType::Softmax { n_classes: 3 }, LossType::Softmax { n_classes: 3 });
        assert_eq!(LossType::Expectile { tau: 0.9 }, LossType::Expectile { tau: 0.9 });
        assert_eq!(LossType::Quantile { tau: 0.5 }, LossType::Quantile { tau: 0.5 });
    }

    #[cfg(feature = "alloc")]
    #[test]
    fn test_into_loss_squared() {
        let reconstructed = LossType::Squared.into_loss();
        assert_eq!(reconstructed.n_outputs(), 1);
        let grad = reconstructed.gradient(3.0, 5.0);
        assert!((grad - 2.0).abs() < 1e-12);
    }

    #[cfg(feature = "alloc")]
    #[test]
    fn test_into_loss_logistic() {
        let reconstructed = LossType::Logistic.into_loss();
        assert_eq!(reconstructed.n_outputs(), 1);
    }

    #[cfg(feature = "alloc")]
    #[test]
    fn test_into_loss_huber() {
        let reconstructed = LossType::Huber { delta: 1.0 }.into_loss();
        assert_eq!(reconstructed.n_outputs(), 1);
    }

    #[cfg(feature = "alloc")]
    #[test]
    fn test_into_loss_softmax() {
        // Softmax is the only multi-output loss: one output per class.
        let reconstructed = LossType::Softmax { n_classes: 5 }.into_loss();
        assert_eq!(reconstructed.n_outputs(), 5);
    }

    #[cfg(feature = "alloc")]
    #[test]
    fn test_into_loss_expectile() {
        let reconstructed = LossType::Expectile { tau: 0.9 }.into_loss();
        assert_eq!(reconstructed.n_outputs(), 1);
    }

    #[cfg(feature = "alloc")]
    #[test]
    fn test_into_loss_quantile() {
        let reconstructed = LossType::Quantile { tau: 0.5 }.into_loss();
        assert_eq!(reconstructed.n_outputs(), 1);
    }

    #[cfg(feature = "alloc")]
    #[test]
    fn test_box_dyn_loss_blanket() {
        // A boxed trait object must forward every Loss method unchanged.
        let boxed: alloc::boxed::Box<dyn Loss> = alloc::boxed::Box::new(squared::SquaredLoss);
        let close = |a: f64, b: f64| (a - b).abs() < 1e-12;

        assert_eq!(boxed.n_outputs(), 1);
        assert!(close(boxed.gradient(3.0, 5.0), 2.0));
        assert!(close(boxed.hessian(3.0, 5.0), 1.0));
        assert!(close(boxed.loss(1.0, 3.0), 2.0));
        assert!(close(boxed.predict_transform(42.0), 42.0));
        assert!(boxed.loss_type().is_some());
    }
}