1use crate::{data::FloatData, metric::Metric};
2use serde::{Deserialize, Serialize};
3
/// Signature shared by every objective's gradient/hessian routine:
/// `(y, yhat, sample_weight) -> (gradients, hessians)`, one `f32` per sample.
type ObjFn = fn(&[f64], &[f64], &[f64]) -> (Vec<f32>, Vec<f32>);
5
6#[derive(Debug, Deserialize, Serialize)]
7pub enum ObjectiveType {
8 LogLoss,
9 SquaredLoss,
10}
11
12pub fn gradient_hessian_callables(objective_type: &ObjectiveType) -> ObjFn {
13 match objective_type {
14 ObjectiveType::LogLoss => LogLoss::calc_grad_hess,
15 ObjectiveType::SquaredLoss => SquaredLoss::calc_grad_hess,
16 }
17}
18
19pub fn calc_init_callables(objective_type: &ObjectiveType) -> fn(&[f64], &[f64]) -> f64 {
20 match objective_type {
21 ObjectiveType::LogLoss => LogLoss::calc_init,
22 ObjectiveType::SquaredLoss => SquaredLoss::calc_init,
23 }
24}
25
/// Contract every boosting objective implements: per-sample loss, first and
/// second derivatives with respect to the raw prediction, a constant initial
/// score, and the metric reported by default.
pub trait ObjectiveFunction {
    /// Weighted per-sample loss for targets `y` and raw predictions `yhat`.
    fn calc_loss(y: &[f64], yhat: &[f64], sample_weight: &[f64]) -> Vec<f32>;
    /// Weighted per-sample `(gradient, hessian)` vectors of the loss w.r.t. `yhat`.
    fn calc_grad_hess(y: &[f64], yhat: &[f64], sample_weight: &[f64]) -> (Vec<f32>, Vec<f32>);
    /// Constant initial prediction derived from `y` and `sample_weight`.
    fn calc_init(y: &[f64], sample_weight: &[f64]) -> f64;
    /// Evaluation metric conventionally paired with this objective.
    fn default_metric() -> Metric;
}
32
/// Binary cross-entropy objective; `yhat` is interpreted on the log-odds
/// scale and passed through a sigmoid before the loss is evaluated.
#[derive(Default)]
pub struct LogLoss {}
35
36impl ObjectiveFunction for LogLoss {
37 #[inline]
38 fn calc_loss(y: &[f64], yhat: &[f64], sample_weight: &[f64]) -> Vec<f32> {
39 y.iter()
40 .zip(yhat)
41 .zip(sample_weight)
42 .map(|((y_, yhat_), w_)| {
43 let yhat_ = f64::ONE / (f64::ONE + (-*yhat_).exp());
44 (-(*y_ * yhat_.ln() + (f64::ONE - *y_) * ((f64::ONE - yhat_).ln())) * *w_) as f32
45 })
46 .collect()
47 }
48
49 fn calc_init(y: &[f64], sample_weight: &[f64]) -> f64 {
50 let mut ytot: f64 = 0.;
51 let mut ntot: f64 = 0.;
52 for i in 0..y.len() {
53 ytot += sample_weight[i] * y[i];
54 ntot += sample_weight[i];
55 }
56 f64::ln(ytot / (ntot - ytot))
57 }
58
59 #[inline]
60 fn calc_grad_hess(y: &[f64], yhat: &[f64], sample_weight: &[f64]) -> (Vec<f32>, Vec<f32>) {
61 y.iter()
62 .zip(yhat)
63 .zip(sample_weight)
64 .map(|((y_, yhat_), w_)| {
65 let yhat_ = f64::ONE / (f64::ONE + (-*yhat_).exp());
66 (
67 ((yhat_ - *y_) * *w_) as f32,
68 (yhat_ * (f64::ONE - yhat_) * *w_) as f32,
69 )
70 })
71 .unzip()
72 }
73
74 fn default_metric() -> Metric {
75 Metric::LogLoss
76 }
77}
78
/// Squared-error (regression) objective on the raw prediction scale.
#[derive(Default)]
pub struct SquaredLoss {}
81
82impl ObjectiveFunction for SquaredLoss {
83 #[inline]
84 fn calc_loss(y: &[f64], yhat: &[f64], sample_weight: &[f64]) -> Vec<f32> {
85 y.iter()
86 .zip(yhat)
87 .zip(sample_weight)
88 .map(|((y_, yhat_), w_)| {
89 let s = *y_ - *yhat_;
90 (s * s * *w_) as f32
91 })
92 .collect()
93 }
94
95 fn calc_init(y: &[f64], sample_weight: &[f64]) -> f64 {
96 let mut ytot: f64 = 0.;
97 let mut ntot: f64 = 0.;
98 for i in 0..y.len() {
99 ytot += sample_weight[i] * y[i];
100 ntot += sample_weight[i];
101 }
102
103 ytot / ntot
104 }
105
106 #[inline]
107 fn calc_grad_hess(y: &[f64], yhat: &[f64], sample_weight: &[f64]) -> (Vec<f32>, Vec<f32>) {
108 y.iter()
109 .zip(yhat)
110 .zip(sample_weight)
111 .map(|((y_, yhat_), w_)| (((yhat_ - *y_) * *w_) as f32, *w_ as f32))
112 .unzip()
113 }
114 fn default_metric() -> Metric {
115 Metric::RootMeanSquaredLogError
116 }
117}
118
#[cfg(test)]
mod tests {
    use super::*;

    // Better predictions (yhat1) must score a strictly lower total loss than
    // worse ones (yhat2).
    #[test]
    fn test_logloss_loss() {
        let y = vec![0.0, 0.0, 0.0, 1.0, 1.0, 1.0];
        let yhat1 = vec![-1.0, -1.0, -1.0, 1.0, 1.0, 1.0];
        let w = vec![1.; y.len()];
        let l1 = LogLoss::calc_loss(&y, &yhat1, &w);
        let yhat2 = vec![0.0, 0.0, -1.0, 1.0, 0.0, 1.0];
        let l2 = LogLoss::calc_loss(&y, &yhat2, &w);
        assert!(l1.iter().sum::<f32>() < l2.iter().sum::<f32>());
    }

    #[test]
    fn test_logloss_grad() {
        let y = vec![0.0, 0.0, 0.0, 1.0, 1.0, 1.0];
        let yhat1 = vec![-1.0, -1.0, -1.0, 1.0, 1.0, 1.0];
        let w = vec![1.; y.len()];
        let (g1, _) = LogLoss::calc_grad_hess(&y, &yhat1, &w);
        let yhat2 = vec![0.0, 0.0, -1.0, 1.0, 0.0, 1.0];
        let (g2, _) = LogLoss::calc_grad_hess(&y, &yhat2, &w);
        assert!(g1.iter().sum::<f32>() < g2.iter().sum::<f32>());
    }

    #[test]
    fn test_logloss_hess() {
        let y = vec![0.0, 0.0, 0.0, 1.0, 1.0, 1.0];
        let yhat1 = vec![-1.0, -1.0, -1.0, 1.0, 1.0, 1.0];
        let w = vec![1.; y.len()];
        let (_, h1) = LogLoss::calc_grad_hess(&y, &yhat1, &w);
        let yhat2 = vec![0.0, 0.0, -1.0, 1.0, 0.0, 1.0];
        let (_, h2) = LogLoss::calc_grad_hess(&y, &yhat2, &w);
        assert!(h1.iter().sum::<f32>() < h2.iter().sum::<f32>());
    }

    // Initial score edge cases: balanced (0), all-positive (+inf),
    // all-negative (-inf), and an unbalanced exact ratio.
    #[test]
    fn test_logloss_init() {
        let y = vec![0.0, 0.0, 0.0, 1.0, 1.0, 1.0];
        let w = vec![1.; y.len()];
        let l1 = LogLoss::calc_init(&y, &w);
        assert!(l1 == 0.);

        let y = vec![1.0; 6];
        let l2 = LogLoss::calc_init(&y, &w);
        assert!(l2 == f64::INFINITY);

        let y = vec![0.0; 6];
        let l3 = LogLoss::calc_init(&y, &w);
        assert!(l3 == f64::NEG_INFINITY);

        let y = vec![0., 0., 0., 0., 1., 1.];
        let l4 = LogLoss::calc_init(&y, &w);
        assert!(l4 == f64::ln(2. / 4.));
    }

    #[test]
    fn test_mse_init() {
        let y = vec![0.0, 0.0, 0.0, 1.0, 1.0, 1.0];
        let w = vec![1.; y.len()];
        let l1 = SquaredLoss::calc_init(&y, &w);
        assert!(l1 == 0.5);

        let y = vec![1.0, 1.0, 1.0, 1.0, 1.0, 1.0];
        let l2 = SquaredLoss::calc_init(&y, &w);
        assert!(l2 == 1.);

        let y = vec![-1.0, -1.0, -1.0, -1.0, -1.0, -1.0];
        let l3 = SquaredLoss::calc_init(&y, &w);
        assert!(l3 == -1.);

        let y = vec![-1.0, -1.0, -1.0, 1., 1., 1.];
        let l4 = SquaredLoss::calc_init(&y, &w);
        assert!(l4 == 0.);
    }

    // Previously untested: SquaredLoss::calc_loss and calc_grad_hess.
    // Values are exactly representable, so exact equality is safe.
    #[test]
    fn test_mse_loss_and_grad_hess() {
        let y = vec![1.0, 2.0, 3.0];
        let yhat = vec![2.0, 4.0, 3.0];
        let w = vec![1.; y.len()];

        // (y - yhat)^2 * w
        let l = SquaredLoss::calc_loss(&y, &yhat, &w);
        assert_eq!(l, vec![1.0_f32, 4.0, 0.0]);

        // grad = (yhat - y) * w, hess = w
        let (g, h) = SquaredLoss::calc_grad_hess(&y, &yhat, &w);
        assert_eq!(g, vec![1.0_f32, 2.0, 0.0]);
        assert_eq!(h, vec![1.0_f32, 1.0, 1.0]);
    }
}