Skip to main content

forust_ml/
metric.rs

1use crate::data::FloatData;
2use crate::errors::ForustError;
3use crate::utils::items_to_strings;
4use serde::{Deserialize, Serialize};
5use std::str::FromStr;
6
7pub type MetricFn = fn(&[f64], &[f64], &[f64]) -> f64;
8
/// Decide whether `comparison` is a better metric value than `value`.
///
/// * A NaN `comparison` is never better, even when `value` is NaN as well.
/// * A non-NaN `comparison` always beats a NaN `value`.
/// * Otherwise the two are compared numerically: `comparison` must be
///   strictly greater when `maximize` is `true`, strictly smaller when it is
///   `false`. Equal values are not an improvement.
pub fn is_comparison_better(value: f64, comparison: f64, maximize: bool) -> bool {
    // A NaN candidate can never replace the current value.
    if comparison.is_nan() {
        return false;
    }
    // Any real number is preferred over a NaN current value.
    if value.is_nan() {
        return true;
    }
    // Both are real numbers: the direction depends on the optimization goal.
    if maximize {
        comparison > value
    } else {
        comparison < value
    }
}
33
/// Evaluation metrics that can be selected by name.
///
/// `metric_callables` maps each variant to the function that computes it and
/// to a flag telling whether larger values are better.
#[derive(Debug, Deserialize, Serialize, Clone, Copy)]
pub enum Metric {
    /// Area under the ROC curve; larger values are better.
    AUC,
    /// Weighted log loss; smaller values are better.
    LogLoss,
    /// Weighted root mean squared logarithmic error; smaller values are better.
    RootMeanSquaredLogError,
    /// Weighted root mean squared error; smaller values are better.
    RootMeanSquaredError,
}
41
42impl FromStr for Metric {
43    type Err = ForustError;
44
45    fn from_str(s: &str) -> Result<Self, Self::Err> {
46        match s {
47            "AUC" => Ok(Metric::AUC),
48            "LogLoss" => Ok(Metric::LogLoss),
49            "RootMeanSquaredLogError" => Ok(Metric::RootMeanSquaredLogError),
50            "RootMeanSquaredError" => Ok(Metric::RootMeanSquaredError),
51
52            _ => Err(ForustError::ParseString(
53                s.to_string(),
54                "Metric".to_string(),
55                items_to_strings(vec![
56                    "AUC",
57                    "LogLoss",
58                    "RootMeanSquaredLogError",
59                    "RootMeanSquaredError",
60                ]),
61            )),
62        }
63    }
64}
65
66pub fn metric_callables(metric_type: &Metric) -> (MetricFn, bool) {
67    match metric_type {
68        Metric::AUC => (AUCMetric::calculate_metric, AUCMetric::maximize()),
69        Metric::LogLoss => (LogLossMetric::calculate_metric, LogLossMetric::maximize()),
70        Metric::RootMeanSquaredLogError => (
71            RootMeanSquaredLogErrorMetric::calculate_metric,
72            RootMeanSquaredLogErrorMetric::maximize(),
73        ),
74        Metric::RootMeanSquaredError => (
75            RootMeanSquaredErrorMetric::calculate_metric,
76            RootMeanSquaredErrorMetric::maximize(),
77        ),
78    }
79}
80
/// Common interface of the evaluation metrics defined in this module.
pub trait EvaluationMetric {
    /// Compute the metric from the targets `y`, the predictions `yhat` and
    /// the per-record `sample_weight`.
    fn calculate_metric(y: &[f64], yhat: &[f64], sample_weight: &[f64]) -> f64;
    /// Whether larger values of this metric are better (`true`) or smaller
    /// values are better (`false`).
    fn maximize() -> bool;
}
85
/// Weighted log loss metric; delegates to `log_loss`.
pub struct LogLossMetric {}
impl EvaluationMetric for LogLossMetric {
    /// Weighted log loss of `yhat` against `y`, as computed by `log_loss`.
    fn calculate_metric(y: &[f64], yhat: &[f64], sample_weight: &[f64]) -> f64 {
        log_loss(y, yhat, sample_weight)
    }
    /// Log loss is minimized, so this returns `false`.
    fn maximize() -> bool {
        false
    }
}
95
/// Area-under-the-ROC-curve metric; delegates to `roc_auc_score`.
pub struct AUCMetric {}
impl EvaluationMetric for AUCMetric {
    /// Weighted ROC AUC of `yhat` against `y`, as computed by `roc_auc_score`.
    fn calculate_metric(y: &[f64], yhat: &[f64], sample_weight: &[f64]) -> f64 {
        roc_auc_score(y, yhat, sample_weight)
    }
    /// AUC is maximized, so this returns `true`.
    fn maximize() -> bool {
        true
    }
}
105
/// Root mean squared logarithmic error metric; delegates to
/// `root_mean_squared_log_error`.
pub struct RootMeanSquaredLogErrorMetric {}
impl EvaluationMetric for RootMeanSquaredLogErrorMetric {
    /// Weighted RMSLE of `yhat` against `y`, as computed by
    /// `root_mean_squared_log_error`.
    fn calculate_metric(y: &[f64], yhat: &[f64], sample_weight: &[f64]) -> f64 {
        root_mean_squared_log_error(y, yhat, sample_weight)
    }
    /// RMSLE is minimized, so this returns `false`.
    fn maximize() -> bool {
        false
    }
}
115
/// Root mean squared error metric; delegates to `root_mean_squared_error`.
pub struct RootMeanSquaredErrorMetric {}
impl EvaluationMetric for RootMeanSquaredErrorMetric {
    /// Weighted RMSE of `yhat` against `y`, as computed by
    /// `root_mean_squared_error`.
    fn calculate_metric(y: &[f64], yhat: &[f64], sample_weight: &[f64]) -> f64 {
        root_mean_squared_error(y, yhat, sample_weight)
    }
    /// RMSE is minimized, so this returns `false`.
    fn maximize() -> bool {
        false
    }
}
125
/// Weighted binary log loss (cross-entropy) of raw log-odds predictions.
///
/// `yhat` holds raw scores on the log-odds scale. With
/// `p = 1 / (1 + exp(-yhat))`, each record contributes
/// `-(y * ln(p) + (1 - y) * ln(1 - p)) * w`, and the total is divided by the
/// sum of the weights that were consumed.
///
/// The per-record loss is evaluated in the algebraically equivalent form
/// `y * softplus(-yhat) + (1 - y) * softplus(yhat)`. Computing `p` first and
/// then taking its logarithm saturates `p` to exactly 0 or 1 for
/// large-magnitude scores, which turns the loss into NaN or infinity; the
/// softplus form stays finite for every finite score.
///
/// The three slices are zipped, so elements beyond the shortest slice are
/// ignored. An empty input yields NaN (zero divided by zero).
pub fn log_loss(y: &[f64], yhat: &[f64], sample_weight: &[f64]) -> f64 {
    /// Numerically stable `ln(1 + exp(x))`.
    fn softplus(x: f64) -> f64 {
        // exp() only ever sees a non-positive argument, so it cannot overflow.
        x.max(0.0) + (-x.abs()).exp().ln_1p()
    }

    let mut weight_total = 0.0;
    let mut loss_total = 0.0;
    for ((&target, &score), &weight) in y.iter().zip(yhat).zip(sample_weight) {
        // -ln(p) == softplus(-score) and -ln(1 - p) == softplus(score).
        let record_loss = target * softplus(-score) + (1.0 - target) * softplus(score);
        loss_total += record_loss * weight;
        weight_total += weight;
    }
    loss_total / weight_total
}
140
/// Weighted root mean squared logarithmic error.
///
/// Each record contributes `(ln(1 + y) - ln(1 + yhat))^2 * w`; the sum is
/// divided by the sum of the weights used and the square root is returned.
///
/// Only the first `n` records are considered, where `n` is the length of the
/// shortest of the three slices. An empty input yields NaN.
pub fn root_mean_squared_log_error(y: &[f64], yhat: &[f64], sample_weight: &[f64]) -> f64 {
    // Number of complete (y, yhat, weight) records.
    let n = y.len().min(yhat.len()).min(sample_weight.len());

    let total_weight: f64 = sample_weight[..n].iter().fold(0.0, |acc, w| acc + w);
    let weighted_sq_log_err = (0..n)
        .map(|i| (y[i].ln_1p() - yhat[i].ln_1p()).powi(2) * sample_weight[i])
        .sum::<f64>();

    (weighted_sq_log_err / total_weight).sqrt()
}
154
/// Weighted root mean squared error.
///
/// Each record contributes `(y - yhat)^2 * w`; the sum is divided by the sum
/// of the weights used and the square root is returned.
///
/// Only the first `n` records are considered, where `n` is the length of the
/// shortest of the three slices. An empty input yields NaN.
pub fn root_mean_squared_error(y: &[f64], yhat: &[f64], sample_weight: &[f64]) -> f64 {
    // Number of complete (y, yhat, weight) records.
    let n = y.len().min(yhat.len()).min(sample_weight.len());

    let total_weight: f64 = sample_weight[..n].iter().fold(0.0, |acc, w| acc + w);
    let weighted_sq_err = (0..n)
        .map(|i| (y[i] - yhat[i]).powi(2) * sample_weight[i])
        .sum::<f64>();

    (weighted_sq_err / total_weight).sqrt()
}
168
/// Area of the trapezoid spanned by the points `(x0, y0)` and `(x1, y1)`
/// above the x-axis: the absolute width times the mean of the two heights.
fn trapezoid_area(x0: f64, x1: f64, y0: f64, y1: f64) -> f64 {
    let width = (x0 - x1).abs();
    let height_sum = y0 + y1;
    width * height_sum * 0.5
}

/// Weighted area under the ROC curve.
///
/// Records are visited in order of decreasing `yhat`. Weighted true-positive
/// (`y * w`) and false-positive (`(1 - y) * w`) totals are accumulated, and a
/// trapezoid is added to the area every time the score changes, so records
/// with equal scores are treated as a single threshold. The area is finally
/// normalized by the product of the two totals.
///
/// Returns NaN when the score is undefined: when the input is empty, or when
/// the weighted total of positives or of negatives is not greater than zero
/// (for example when only one class is present).
///
/// # Panics
///
/// Panics if `yhat` or `sample_weight` is shorter than `y`.
pub fn roc_auc_score(y: &[f64], yhat: &[f64], sample_weight: &[f64]) -> f64 {
    // Without any record there is no curve; report "undefined" instead of
    // panicking on the first-element access below.
    if y.is_empty() {
        return f64::NAN;
    }

    // Record positions ordered by decreasing score.
    let mut indices = (0..y.len()).collect::<Vec<_>>();
    indices.sort_unstable_by(|&a, &b| yhat[b].total_cmp(&yhat[a]));

    let first = indices[0];
    let mut fp = (1.0 - y[first]) * sample_weight[first];
    let mut tp = y[first] * sample_weight[first];
    let mut tp_prev = 0.0;
    let mut fp_prev = 0.0;
    let mut auc = 0.0;

    for pair in indices.windows(2) {
        let (prev, cur) = (pair[0], pair[1]);
        // A new score value closes the current group of tied records.
        if yhat[cur] != yhat[prev] {
            auc += trapezoid_area(fp_prev, fp, tp_prev, tp);
            tp_prev = tp;
            fp_prev = fp;
        }
        let label = y[cur];
        let w = sample_weight[cur];
        fp += (1.0 - label) * w;
        tp += label * w;
    }

    // Close the last group.
    auc += trapezoid_area(fp_prev, fp, tp_prev, tp);

    // With no positive or no negative weight the curve is degenerate.
    if fp <= 0.0 || tp <= 0.0 {
        return f64::NAN;
    }

    auc / (tp * fp)
}
206
#[cfg(test)]
mod tests {
    use super::*;
    use crate::utils::precision_round;

    // Fixture shared by the two regression metrics.
    const REG_Y: [f64; 7] = [1., 3., 4., 5., 2., 4., 6.];
    const REG_YHAT: [f64; 7] = [3., 2., 3., 4., 4., 4., 4.];

    // Fixture shared by the two binary classification metrics.
    const CLS_Y: [f64; 7] = [1., 0., 1., 0., 0., 0., 0.];
    const CLS_YHAT: [f64; 7] = [0.5, 0.01, -0., 1.05, 0., -4., 0.];

    // Per-record weights used with both fixtures.
    const WEIGHTS: [f64; 7] = [1., 1., 1., 1., 1., 2., 2.];

    #[test]
    fn test_root_mean_squared_log_error() {
        let res = root_mean_squared_log_error(&REG_Y, &REG_YHAT, &WEIGHTS);
        assert_eq!(precision_round(res, 4), 0.3549);
    }

    #[test]
    fn test_root_mean_squared_error() {
        let res = root_mean_squared_error(&REG_Y, &REG_YHAT, &WEIGHTS);
        assert_eq!(precision_round(res, 6), 1.452966);
    }

    #[test]
    fn test_log_loss() {
        let res = log_loss(&CLS_Y, &CLS_YHAT, &WEIGHTS);
        assert_eq!(precision_round(res, 5), 0.59235);
    }

    #[test]
    fn test_auc_real_data() {
        let res = roc_auc_score(&CLS_Y, &CLS_YHAT, &WEIGHTS);
        assert_eq!(precision_round(res, 5), 0.67857);
    }

    #[test]
    fn test_auc_generc() {
        let sample_weight: [f64; 2] = [1.; 2];

        // (labels, scores, expected AUC); `None` means the AUC is undefined
        // and the function is expected to return NaN.
        let cases: [([f64; 2], [f64; 2], Option<f64>); 7] = [
            ([0., 1.], [0., 1.], Some(1.)),
            ([0., 1.], [1., 0.], Some(0.)),
            ([1., 0.], [1., 1.], Some(0.5)),
            ([1., 0.], [1., 0.], Some(1.0)),
            ([1., 0.], [0.5, 0.5], Some(0.5)),
            ([0., 0.], [0.25, 0.75], None),
            ([1., 1.], [0.25, 0.75], None),
        ];

        for (y, yhat, expected) in cases.iter() {
            let auc_score = roc_auc_score(y, yhat, &sample_weight);
            match *expected {
                Some(value) => assert_eq!(auc_score, value),
                None => assert!(auc_score.is_nan()),
            }
        }
    }
}
285}