libcaliph/fit.rs
1/* This Source Code Form is subject to the terms of the Mozilla Public
2License, v. 2.0. If a copy of the MPL was not distributed with this
3file, You can obtain one at https://mozilla.org/MPL/2.0/.
4Copyright 2021 Peter Dunne */
5
//! Provides methods to perform a linear fit of two input arrays, `x` and `y`.
7use super::stats;
8
9/// Linear regression of x,y data
10/// Returns an array of [slope, offset]
11/// ```
12/// use crate::libcaliph::fit::fit;
13/// use float_cmp::approx_eq;
14/// let x = [1.0, 2.0, 3.0];
15/// let y = [3.0, 5.0, 7.0];
16///
17/// let result = fit(&x, &y);
18/// let comparison = [2.0, 1.0];
19///
20/// assert!(approx_eq!(f64, result[0], comparison[0]) && approx_eq!(f64, result[0], comparison[0]) );
21/// ```
22///
23pub fn fit(x: &[f64], y: &[f64]) -> [f64; 2] {
24 let slope = stats::covariance(x, y) / stats::variance(x);
25 let intercept = stats::mean(y) - slope * stats::mean(x);
26 [slope, intercept]
27}
28
/// Evaluates the linear `model` at the point `x`.
///
/// `model` is laid out as `[slope, intercept]` (the same order [`fit`]
/// returns), so the result is `slope * x + intercept`.
///
/// ```
/// use crate::libcaliph::fit::predict;
/// use float_cmp::approx_eq;
///
/// let model = [2.0, 3.5];
/// let result = predict(&1.5, &model);
/// assert!(approx_eq!(f64, result, 6.5));
/// ```
pub fn predict(x: &f64, model: &[f64; 2]) -> f64 {
    let [slope, intercept] = *model;
    let scaled = *x * slope;
    scaled + intercept
}
43
/// Root-mean-squared error (RMSE) between `actual` and `predicted` values.
///
/// Values are compared element-wise. If the slices differ in length, the
/// extra tail of the longer one is ignored, and the mean is taken over the
/// pairs that were actually compared. Returns `NaN` when there are no pairs.
fn root_mean_squared_error(actual: &[f64], predicted: &[f64]) -> f64 {
    // `zip` stops at the shorter slice, so divide by exactly the number of
    // squared errors that were summed (not by `actual.len()`).
    let count = actual.len().min(predicted.len());

    let sum_squared_error: f64 = predicted
        .iter()
        .zip(actual)
        .map(|(p, a)| (p - a).powi(2))
        .sum();

    (sum_squared_error / count as f64).sqrt()
}
56
/// Coefficient of determination (R²) of a fit, given the observed values `y`
/// and the root-mean-squared error `rms` of the fit's predictions.
///
/// R² = 1 - SS_res / SS_tot, where SS_tot = Σ (yᵢ - ȳ)². SS_res is recovered
/// from the RMSE as `n · rms²`, since the RMSE averages the squared residuals
/// over `n = y.len()` points. Equivalently, `1 - rms² / var(y)` using the
/// population (divide-by-`n`) variance.
///
/// Returns `NaN` for empty `y`. If every `y` value is identical, SS_tot is
/// zero and the result is not finite.
fn rsquared(y: &[f64], rms: &f64) -> f64 {
    let n = y.len() as f64;
    let mean = y.iter().sum::<f64>() / n;
    let total_sum_squares: f64 = y.iter().map(|v| (v - mean).powi(2)).sum();
    // The RMS must be squared: R² compares sums of *squared* errors.
    let residual_sum_squares = n * rms.powi(2);
    1.0 - residual_sum_squares / total_sum_squares
}
60
61/// Evaluates all data in a model, returning the root mean squared error (RMSE), and the R-Squared goodness of fit
62///
63/// ```
64/// use crate::libcaliph::fit::{fit,evaluate};
65/// use float_cmp::approx_eq;
66///
67/// let x = [1.05, 1.992, 3.03];
68/// let y = [2.993, 4.92, 6.99];
69///
70/// let model = fit(&x, &y);
71/// let result = evaluate(&x, &y, &model);
72///
73/// let comparison = [1.19675583971723e-2, 0.99550];
74/// println!("{:.e}", result[0]);
75/// println!("{}", result[1]);
76/// assert!(approx_eq!(f64, result[0], comparison[0]) && approx_eq!(f64, result[0], comparison[0]) );
77///```
78pub fn evaluate(x: &[f64], y: &[f64], model: &[f64; 2]) -> [f64; 2] {
79 let y_predicted: Vec<f64> = x.iter().map(|y| predict(y, model)).collect();
80 let rms = root_mean_squared_error(y, &y_predicted);
81 [rms, rsquared(y, &rms)]
82}