forust_ml/
objective.rs

1use crate::{data::FloatData, metric::Metric};
2use serde::{Deserialize, Serialize};
3
/// Function-pointer type shared by every objective's gradient/hessian routine.
///
/// Receives the targets, the predictions and the per-sample weights, and returns the
/// per-sample `(gradient, hessian)` vectors.
type ObjFn = fn(y: &[f64], yhat: &[f64], sample_weight: &[f64]) -> (Vec<f32>, Vec<f32>);
5
6#[derive(Debug, Deserialize, Serialize)]
7pub enum ObjectiveType {
8    LogLoss,
9    SquaredLoss,
10}
11
12pub fn gradient_hessian_callables(objective_type: &ObjectiveType) -> ObjFn {
13    match objective_type {
14        ObjectiveType::LogLoss => LogLoss::calc_grad_hess,
15        ObjectiveType::SquaredLoss => SquaredLoss::calc_grad_hess,
16    }
17}
18
19pub fn calc_init_callables(objective_type: &ObjectiveType) -> fn(&[f64], &[f64]) -> f64 {
20    match objective_type {
21        ObjectiveType::LogLoss => LogLoss::calc_init,
22        ObjectiveType::SquaredLoss => SquaredLoss::calc_init,
23    }
24}
25
26pub trait ObjectiveFunction {
27    fn calc_loss(y: &[f64], yhat: &[f64], sample_weight: &[f64]) -> Vec<f32>;
28    fn calc_grad_hess(y: &[f64], yhat: &[f64], sample_weight: &[f64]) -> (Vec<f32>, Vec<f32>);
29    fn calc_init(y: &[f64], sample_weight: &[f64]) -> f64;
30    fn default_metric() -> Metric;
31}
32
33#[derive(Default)]
34pub struct LogLoss {}
35
36impl ObjectiveFunction for LogLoss {
37    #[inline]
38    fn calc_loss(y: &[f64], yhat: &[f64], sample_weight: &[f64]) -> Vec<f32> {
39        y.iter()
40            .zip(yhat)
41            .zip(sample_weight)
42            .map(|((y_, yhat_), w_)| {
43                let yhat_ = f64::ONE / (f64::ONE + (-*yhat_).exp());
44                (-(*y_ * yhat_.ln() + (f64::ONE - *y_) * ((f64::ONE - yhat_).ln())) * *w_) as f32
45            })
46            .collect()
47    }
48
49    fn calc_init(y: &[f64], sample_weight: &[f64]) -> f64 {
50        let mut ytot: f64 = 0.;
51        let mut ntot: f64 = 0.;
52        for i in 0..y.len() {
53            ytot += sample_weight[i] * y[i];
54            ntot += sample_weight[i];
55        }
56        f64::ln(ytot / (ntot - ytot))
57    }
58
59    #[inline]
60    fn calc_grad_hess(y: &[f64], yhat: &[f64], sample_weight: &[f64]) -> (Vec<f32>, Vec<f32>) {
61        y.iter()
62            .zip(yhat)
63            .zip(sample_weight)
64            .map(|((y_, yhat_), w_)| {
65                let yhat_ = f64::ONE / (f64::ONE + (-*yhat_).exp());
66                (
67                    ((yhat_ - *y_) * *w_) as f32,
68                    (yhat_ * (f64::ONE - yhat_) * *w_) as f32,
69                )
70            })
71            .unzip()
72    }
73
74    fn default_metric() -> Metric {
75        Metric::LogLoss
76    }
77}
78
79#[derive(Default)]
80pub struct SquaredLoss {}
81
82impl ObjectiveFunction for SquaredLoss {
83    #[inline]
84    fn calc_loss(y: &[f64], yhat: &[f64], sample_weight: &[f64]) -> Vec<f32> {
85        y.iter()
86            .zip(yhat)
87            .zip(sample_weight)
88            .map(|((y_, yhat_), w_)| {
89                let s = *y_ - *yhat_;
90                (s * s * *w_) as f32
91            })
92            .collect()
93    }
94
95    fn calc_init(y: &[f64], sample_weight: &[f64]) -> f64 {
96        let mut ytot: f64 = 0.;
97        let mut ntot: f64 = 0.;
98        for i in 0..y.len() {
99            ytot += sample_weight[i] * y[i];
100            ntot += sample_weight[i];
101        }
102
103        ytot / ntot
104    }
105
106    #[inline]
107    fn calc_grad_hess(y: &[f64], yhat: &[f64], sample_weight: &[f64]) -> (Vec<f32>, Vec<f32>) {
108        y.iter()
109            .zip(yhat)
110            .zip(sample_weight)
111            .map(|((y_, yhat_), w_)| (((yhat_ - *y_) * *w_) as f32, *w_ as f32))
112            .unzip()
113    }
114    fn default_metric() -> Metric {
115        Metric::RootMeanSquaredLogError
116    }
117}
118
#[cfg(test)]
mod tests {
    use super::*;

    /// Three negatives followed by three positives.
    fn balanced_labels() -> Vec<f64> {
        vec![0.0, 0.0, 0.0, 1.0, 1.0, 1.0]
    }

    /// Raw scores that agree confidently with `balanced_labels`.
    fn confident_scores() -> Vec<f64> {
        vec![-1.0, -1.0, -1.0, 1.0, 1.0, 1.0]
    }

    /// Less certain raw scores (several sit at 0.0) for the same labels.
    fn hesitant_scores() -> Vec<f64> {
        vec![0.0, 0.0, -1.0, 1.0, 0.0, 1.0]
    }

    /// Unit sample weights for `n` samples.
    fn unit_weights(n: usize) -> Vec<f64> {
        vec![1.0; n]
    }

    #[test]
    fn test_logloss_loss() {
        let labels = balanced_labels();
        let weights = unit_weights(labels.len());
        let confident: f32 = LogLoss::calc_loss(&labels, &confident_scores(), &weights)
            .iter()
            .sum();
        let hesitant: f32 = LogLoss::calc_loss(&labels, &hesitant_scores(), &weights)
            .iter()
            .sum();
        assert!(confident < hesitant);
    }

    #[test]
    fn test_logloss_grad() {
        let labels = balanced_labels();
        let weights = unit_weights(labels.len());
        let (confident, _) = LogLoss::calc_grad_hess(&labels, &confident_scores(), &weights);
        let (hesitant, _) = LogLoss::calc_grad_hess(&labels, &hesitant_scores(), &weights);
        assert!(confident.iter().sum::<f32>() < hesitant.iter().sum::<f32>());
    }

    #[test]
    fn test_logloss_hess() {
        let labels = balanced_labels();
        let weights = unit_weights(labels.len());
        let (_, confident) = LogLoss::calc_grad_hess(&labels, &confident_scores(), &weights);
        let (_, hesitant) = LogLoss::calc_grad_hess(&labels, &hesitant_scores(), &weights);
        assert!(confident.iter().sum::<f32>() < hesitant.iter().sum::<f32>());
    }

    #[test]
    fn test_logloss_init() {
        let weights = unit_weights(6);
        assert_eq!(LogLoss::calc_init(&balanced_labels(), &weights), 0.);
        assert_eq!(LogLoss::calc_init(&[1.0; 6], &weights), f64::INFINITY);
        assert_eq!(LogLoss::calc_init(&[0.0; 6], &weights), f64::NEG_INFINITY);
        assert_eq!(
            LogLoss::calc_init(&[0., 0., 0., 0., 1., 1.], &weights),
            f64::ln(2. / 4.)
        );
    }

    #[test]
    fn test_mse_init() {
        let weights = unit_weights(6);
        // (targets, expected weighted mean)
        let cases: [([f64; 6], f64); 4] = [
            ([0.0, 0.0, 0.0, 1.0, 1.0, 1.0], 0.5),
            ([1.0; 6], 1.),
            ([-1.0; 6], -1.),
            ([-1.0, -1.0, -1.0, 1., 1., 1.], 0.),
        ];
        for (labels, expected) in cases.iter() {
            assert_eq!(SquaredLoss::calc_init(labels, &weights), *expected);
        }
    }
}