1use crate::data::FloatData;
2use crate::errors::ForustError;
3use crate::utils::items_to_strings;
4use serde::{Deserialize, Serialize};
5use std::str::FromStr;
6
/// Signature shared by every metric: targets, predictions and per-record weights in, score out.
pub type MetricFn = fn(y: &[f64], yhat: &[f64], sample_weight: &[f64]) -> f64;
8
/// Returns `true` when `comparison` is a strictly better metric value than `value`.
///
/// `maximize` selects the direction: when `true` a larger `comparison` wins, otherwise a
/// smaller one does. NaN is treated as worse than any number:
/// * a NaN `comparison` is never better (not even than a NaN `value`);
/// * any non-NaN `comparison` is better than a NaN `value`.
///
/// Equal values are not considered an improvement.
pub fn is_comparison_better(value: f64, comparison: f64, maximize: bool) -> bool {
    if comparison.is_nan() {
        // A NaN candidate can never improve on the current value.
        return false;
    }
    if value.is_nan() {
        // Any real number beats a NaN baseline.
        return true;
    }
    if maximize {
        comparison > value
    } else {
        comparison < value
    }
}
33
34#[derive(Debug, Deserialize, Serialize, Clone, Copy)]
35pub enum Metric {
36 AUC,
37 LogLoss,
38 RootMeanSquaredLogError,
39 RootMeanSquaredError,
40}
41
42impl FromStr for Metric {
43 type Err = ForustError;
44
45 fn from_str(s: &str) -> Result<Self, Self::Err> {
46 match s {
47 "AUC" => Ok(Metric::AUC),
48 "LogLoss" => Ok(Metric::LogLoss),
49 "RootMeanSquaredLogError" => Ok(Metric::RootMeanSquaredLogError),
50 "RootMeanSquaredError" => Ok(Metric::RootMeanSquaredError),
51
52 _ => Err(ForustError::ParseString(
53 s.to_string(),
54 "Metric".to_string(),
55 items_to_strings(vec![
56 "AUC",
57 "LogLoss",
58 "RootMeanSquaredLogError",
59 "RootMeanSquaredError",
60 ]),
61 )),
62 }
63 }
64}
65
66pub fn metric_callables(metric_type: &Metric) -> (MetricFn, bool) {
67 match metric_type {
68 Metric::AUC => (AUCMetric::calculate_metric, AUCMetric::maximize()),
69 Metric::LogLoss => (LogLossMetric::calculate_metric, LogLossMetric::maximize()),
70 Metric::RootMeanSquaredLogError => (
71 RootMeanSquaredLogErrorMetric::calculate_metric,
72 RootMeanSquaredLogErrorMetric::maximize(),
73 ),
74 Metric::RootMeanSquaredError => (
75 RootMeanSquaredErrorMetric::calculate_metric,
76 RootMeanSquaredErrorMetric::maximize(),
77 ),
78 }
79}
80
/// A weighted evaluation metric with a fixed optimisation direction.
pub trait EvaluationMetric {
    /// `true` when larger metric values are better (as for AUC), `false` when smaller values
    /// are better (as for the loss and error metrics).
    fn maximize() -> bool;

    /// Scores predictions `yhat` against targets `y`, weighting each record by the matching
    /// entry of `sample_weight`.
    fn calculate_metric(y: &[f64], yhat: &[f64], sample_weight: &[f64]) -> f64;
}
85
86pub struct LogLossMetric {}
87impl EvaluationMetric for LogLossMetric {
88 fn calculate_metric(y: &[f64], yhat: &[f64], sample_weight: &[f64]) -> f64 {
89 log_loss(y, yhat, sample_weight)
90 }
91 fn maximize() -> bool {
92 false
93 }
94}
95
96pub struct AUCMetric {}
97impl EvaluationMetric for AUCMetric {
98 fn calculate_metric(y: &[f64], yhat: &[f64], sample_weight: &[f64]) -> f64 {
99 roc_auc_score(y, yhat, sample_weight)
100 }
101 fn maximize() -> bool {
102 true
103 }
104}
105
106pub struct RootMeanSquaredLogErrorMetric {}
107impl EvaluationMetric for RootMeanSquaredLogErrorMetric {
108 fn calculate_metric(y: &[f64], yhat: &[f64], sample_weight: &[f64]) -> f64 {
109 root_mean_squared_log_error(y, yhat, sample_weight)
110 }
111 fn maximize() -> bool {
112 false
113 }
114}
115
116pub struct RootMeanSquaredErrorMetric {}
117impl EvaluationMetric for RootMeanSquaredErrorMetric {
118 fn calculate_metric(y: &[f64], yhat: &[f64], sample_weight: &[f64]) -> f64 {
119 root_mean_squared_error(y, yhat, sample_weight)
120 }
121 fn maximize() -> bool {
122 false
123 }
124}
125
/// Weighted binary cross-entropy for predictions given as log-odds.
///
/// Each `yhat` entry is a raw score `z`. It is interpreted through the logistic function
/// `p = 1 / (1 + e^-z)` and scored against the target `y` as `-(y ln p + (1 - y) ln(1 - p))`.
/// Each record's loss is multiplied by its `sample_weight`, and the total is divided by the
/// sum of weights. The three slices are walked in lock-step, and iteration stops at the
/// shortest one. Empty input yields NaN (`0 / 0`).
///
/// The loss is evaluated in log space as `y * softplus(-z) + (1 - y) * softplus(z)`, which
/// is algebraically identical to the formula above. Computing `p` explicitly breaks down
/// once `|z|` exceeds roughly 37: `p` rounds to exactly `0.0` or `1.0`, and the result
/// becomes `0 * ln(0) = NaN` even for a confidently *correct* prediction.
pub fn log_loss(y: &[f64], yhat: &[f64], sample_weight: &[f64]) -> f64 {
    /// Numerically stable `ln(1 + e^x)`: never overflows and never takes `ln(0)`.
    fn softplus(x: f64) -> f64 {
        x.max(0.0) + (-x.abs()).exp().ln_1p()
    }

    let (weighted_loss, total_weight) = y.iter().zip(yhat).zip(sample_weight).fold(
        (0.0_f64, 0.0_f64),
        |(loss_acc, weight_acc), ((&target, &logit), &weight)| {
            // softplus(-z) = -ln(p), softplus(z) = -ln(1 - p)
            let loss = target * softplus(-logit) + (1.0 - target) * softplus(logit);
            (loss_acc + loss * weight, weight_acc + weight)
        },
    );
    weighted_loss / total_weight
}
140
/// Weighted root mean squared error between `ln(1 + y)` and `ln(1 + yhat)`.
///
/// Each squared log difference is multiplied by its `sample_weight`. The total is divided
/// by the sum of weights before the square root is taken. The slices are walked in
/// lock-step and iteration stops at the shortest one. Empty input yields NaN (`0 / 0`).
pub fn root_mean_squared_log_error(y: &[f64], yhat: &[f64], sample_weight: &[f64]) -> f64 {
    let (weighted_sq_sum, total_weight) = y
        .iter()
        .zip(yhat.iter())
        .zip(sample_weight.iter())
        .fold((0.0_f64, 0.0_f64), |(sq_acc, w_acc), ((&target, &pred), &weight)| {
            let log_diff = target.ln_1p() - pred.ln_1p();
            (sq_acc + log_diff.powi(2) * weight, w_acc + weight)
        });
    (weighted_sq_sum / total_weight).sqrt()
}
154
/// Weighted root mean squared error between `y` and `yhat`.
///
/// Each squared residual is multiplied by its `sample_weight`. The total is divided by the
/// sum of weights before the square root is taken. The slices are walked in lock-step and
/// iteration stops at the shortest one. Empty input yields NaN (`0 / 0`).
pub fn root_mean_squared_error(y: &[f64], yhat: &[f64], sample_weight: &[f64]) -> f64 {
    let mut total_weight = 0.0_f64;
    let mut weighted_sq_err = 0.0_f64;
    for ((&target, &pred), &weight) in y.iter().zip(yhat).zip(sample_weight) {
        let residual = target - pred;
        weighted_sq_err += residual.powi(2) * weight;
        total_weight += weight;
    }
    (weighted_sq_err / total_weight).sqrt()
}
168
/// Area of the trapezoid spanning `x0..x1` horizontally, with parallel sides `y0` and `y1`.
///
/// The width is taken as an absolute value, so the order of `x0` and `x1` does not matter.
fn trapezoid_area(x0: f64, x1: f64, y0: f64, y1: f64) -> f64 {
    let width = (x1 - x0).abs();
    let summed_heights = y0 + y1;
    width * summed_heights * 0.5
}
172
173pub fn roc_auc_score(y: &[f64], yhat: &[f64], sample_weight: &[f64]) -> f64 {
174 let mut indices = (0..y.len()).collect::<Vec<_>>();
175 indices.sort_unstable_by(|&a, &b| yhat[b].total_cmp(&yhat[a]));
176 let mut auc: f64 = 0.0;
177
178 let mut label = y[indices[0]];
179 let mut w = sample_weight[indices[0]];
180 let mut fp = (1.0 - label) * w;
181 let mut tp: f64 = label * w;
182 let mut tp_prev: f64 = 0.0;
183 let mut fp_prev: f64 = 0.0;
184
185 for i in 1..indices.len() {
186 if yhat[indices[i]] != yhat[indices[i - 1]] {
187 auc += trapezoid_area(fp_prev, fp, tp_prev, tp);
188 tp_prev = tp;
189 fp_prev = fp;
190 }
191 label = y[indices[i]];
192 w = sample_weight[indices[i]];
193 fp += (1.0 - label) * w;
194 tp += label * w;
195 }
196
197 auc += trapezoid_area(fp_prev, fp, tp_prev, tp);
198 if fp <= 0.0 || tp <= 0.0 {
199 auc = 0.0;
200 fp = 0.0;
201 tp = 0.0;
202 }
203
204 auc / (tp * fp)
205}
206
#[cfg(test)]
mod tests {
    use super::*;
    use crate::utils::precision_round;

    #[test]
    fn test_root_mean_squared_log_error() {
        let y = vec![1., 3., 4., 5., 2., 4., 6.];
        let yhat = vec![3., 2., 3., 4., 4., 4., 4.];
        let sample_weight = vec![1., 1., 1., 1., 1., 2., 2.];
        let res = root_mean_squared_log_error(&y, &yhat, &sample_weight);
        assert_eq!(precision_round(res, 4), 0.3549);
    }

    #[test]
    fn test_root_mean_squared_error() {
        let y = vec![1., 3., 4., 5., 2., 4., 6.];
        let yhat = vec![3., 2., 3., 4., 4., 4., 4.];
        let sample_weight = vec![1., 1., 1., 1., 1., 2., 2.];
        let res = root_mean_squared_error(&y, &yhat, &sample_weight);
        assert_eq!(precision_round(res, 6), 1.452966);
    }

    #[test]
    fn test_log_loss() {
        let y = vec![1., 0., 1., 0., 0., 0., 0.];
        let yhat = vec![0.5, 0.01, -0., 1.05, 0., -4., 0.];
        let sample_weight = vec![1., 1., 1., 1., 1., 2., 2.];
        let res = log_loss(&y, &yhat, &sample_weight);
        assert_eq!(precision_round(res, 5), 0.59235);
    }

    #[test]
    fn test_auc_real_data() {
        let y = vec![1., 0., 1., 0., 0., 0., 0.];
        let yhat = vec![0.5, 0.01, -0., 1.05, 0., -4., 0.];
        let sample_weight = vec![1., 1., 1., 1., 1., 2., 2.];
        let res = roc_auc_score(&y, &yhat, &sample_weight);
        assert_eq!(precision_round(res, 5), 0.67857);
    }

    #[test]
    fn test_auc_generic() {
        let sample_weight: Vec<f64> = vec![1.; 2];

        // Perfect ranking.
        let y: Vec<f64> = vec![0., 1.];
        let yhat: Vec<f64> = vec![0., 1.];
        let auc_score = roc_auc_score(&y, &yhat, &sample_weight);
        assert_eq!(auc_score, 1.);

        // Perfectly inverted ranking.
        let y: Vec<f64> = vec![0., 1.];
        let yhat: Vec<f64> = vec![1., 0.];
        let auc_score = roc_auc_score(&y, &yhat, &sample_weight);
        assert_eq!(auc_score, 0.);

        // Tied scores count as a single threshold.
        let y: Vec<f64> = vec![1., 0.];
        let yhat: Vec<f64> = vec![1., 1.];
        let auc_score = roc_auc_score(&y, &yhat, &sample_weight);
        assert_eq!(auc_score, 0.5);

        let y: Vec<f64> = vec![1., 0.];
        let yhat: Vec<f64> = vec![1., 0.];
        let auc_score = roc_auc_score(&y, &yhat, &sample_weight);
        assert_eq!(auc_score, 1.0);

        let y: Vec<f64> = vec![1., 0.];
        let yhat: Vec<f64> = vec![0.5, 0.5];
        let auc_score = roc_auc_score(&y, &yhat, &sample_weight);
        assert_eq!(auc_score, 0.5);

        // A single class leaves the AUC undefined.
        let y: Vec<f64> = vec![0., 0.];
        let yhat: Vec<f64> = vec![0.25, 0.75];
        let auc_score = roc_auc_score(&y, &yhat, &sample_weight);
        assert!(auc_score.is_nan());

        let y: Vec<f64> = vec![1., 1.];
        let yhat: Vec<f64> = vec![0.25, 0.75];
        let auc_score = roc_auc_score(&y, &yhat, &sample_weight);
        assert!(auc_score.is_nan());
    }

    #[test]
    fn test_is_comparison_better() {
        assert!(is_comparison_better(0.5, 0.7, true));
        assert!(!is_comparison_better(0.5, 0.7, false));
        assert!(!is_comparison_better(0.5, 0.5, true));
        // NaN is worse than any number, and a NaN candidate never wins.
        assert!(is_comparison_better(f64::NAN, 0.1, false));
        assert!(!is_comparison_better(0.1, f64::NAN, true));
        assert!(!is_comparison_better(f64::NAN, f64::NAN, true));
    }

    #[test]
    fn test_metric_from_str() {
        assert!(matches!("AUC".parse::<Metric>(), Ok(Metric::AUC)));
        assert!(matches!("LogLoss".parse::<Metric>(), Ok(Metric::LogLoss)));
        assert!(matches!(
            "RootMeanSquaredLogError".parse::<Metric>(),
            Ok(Metric::RootMeanSquaredLogError)
        ));
        assert!(matches!(
            "RootMeanSquaredError".parse::<Metric>(),
            Ok(Metric::RootMeanSquaredError)
        ));
        // Parsing is case-sensitive.
        assert!("auc".parse::<Metric>().is_err());
    }
}
285}