criterion_stats/bivariate/
regression.rs

1//! Regression analysis
2
3use float::Float;
4
5use bivariate::Data;
6
7/// A straight line that passes through the origin `y = m * x`
8#[derive(Clone, Copy)]
9pub struct Slope<A>(pub A)
10where
11    A: Float;
12
13impl<A> Slope<A>
14where
15    A: Float,
16{
17    /// Fits the data to a straight line that passes through the origin using ordinary least
18    /// squares
19    ///
20    /// - Time: `O(length)`
21    pub fn fit(data: &Data<A, A>) -> Slope<A> {
22        let xs = data.0;
23        let ys = data.1;
24
25        let xy = ::dot(xs, ys);
26        let x2 = ::dot(xs, xs);
27
28        Slope(xy / x2)
29    }
30
31    /// Computes the goodness of fit (coefficient of determination) for this data set
32    ///
33    /// - Time: `O(length)`
34    pub fn r_squared(&self, data: &Data<A, A>) -> A {
35        let _0 = A::cast(0);
36        let _1 = A::cast(1);
37        let m = self.0;
38        let xs = data.0;
39        let ys = data.1;
40
41        let n = A::cast(xs.len());
42        let y_bar = ::sum(ys) / n;
43
44        let mut ss_res = _0;
45        let mut ss_tot = _0;
46
47        for (&x, &y) in data.iter() {
48            ss_res = ss_res + (y - m * x).powi(2);
49            ss_tot = ss_res + (y - y_bar).powi(2);
50        }
51
52        _1 - ss_res / ss_tot
53    }
54}
55
56/// A straight line `y = m * x + b`
57#[derive(Clone, Copy)]
58pub struct StraightLine<A>
59where
60    A: Float,
61{
62    /// The y-intercept of the line
63    pub intercept: A,
64    /// The slope of the line
65    pub slope: A,
66}
67
68impl<A> StraightLine<A>
69where
70    A: Float,
71{
72    /// Fits the data to a straight line using ordinary least squares
73    ///
74    /// - Time: `O(length)`
75    #[cfg_attr(feature = "cargo-clippy", allow(clippy::similar_names))]
76    pub fn fit(data: Data<A, A>) -> StraightLine<A> {
77        let xs = data.0;
78        let ys = data.1;
79
80        let x2 = ::dot(xs, xs);
81        let xy = ::dot(xs, ys);
82
83        let n = A::cast(xs.len());
84        let x2_bar = x2 / n;
85        let x_bar = ::sum(xs) / n;
86        let xy_bar = xy / n;
87        let y_bar = ::sum(ys) / n;
88
89        let slope = {
90            let num = xy_bar - x_bar * y_bar;
91            let den = x2_bar - x_bar * x_bar;
92
93            num / den
94        };
95
96        let intercept = y_bar - slope * x_bar;
97
98        StraightLine { intercept, slope }
99    }
100
101    /// Computes the goodness of fit (coefficient of determination) for this data set
102    ///
103    /// - Time: `O(length)`
104    pub fn r_squared(&self, data: Data<A, A>) -> A {
105        let _0 = A::cast(0);
106        let _1 = A::cast(1);
107        let m = self.slope;
108        let b = self.intercept;
109        let xs = data.0;
110        let ys = data.1;
111
112        let n = A::cast(xs.len());
113        let y_bar = ::sum(ys) / n;
114
115        let mut ss_res = _0;
116        let mut ss_tot = _0;
117        for (&x, &y) in data.iter() {
118            ss_res = ss_res + (y - m * x - b).powi(2);
119            ss_tot = ss_tot + (y - y_bar).powi(2);
120        }
121
122        _1 - ss_res / ss_tot
123    }
124}
125
// Generates a module named after the float type `$ty` holding a quickcheck property which
// checks that the r-squared of a fitted `StraightLine` stays within `[0, 1]`.
#[cfg(test)]
macro_rules! test {
    ($ty:ident) => {
        mod $ty {
            use quickcheck::TestResult;

            use bivariate::regression::StraightLine;
            use bivariate::Data;

            quickcheck! {
                fn r_squared(size: usize, start: usize, offset: usize) -> TestResult {
                    match ::test::vec::<$ty>(size, start) {
                        // No usable sample could be generated for these parameters.
                        None => TestResult::discard(),
                        Some(x) => {
                            let y = ::test::vec::<$ty>(size + offset, start + offset).unwrap();
                            let data = Data::new(&x[start..], &y[start+offset..]);

                            let r_squared = StraightLine::fit(data).r_squared(data);

                            // Allow for rounding error at both ends of the interval.
                            let not_below_zero =
                                r_squared > 0. || relative_eq!(r_squared, 0.);
                            let not_above_one =
                                r_squared < 1. || relative_eq!(r_squared, 1.);

                            TestResult::from_bool(not_below_zero && not_above_one)
                        }
                    }
                }
            }
        }
    };
}
157
#[cfg(test)]
mod test {
    // Instantiate the property tests once per supported floating-point type.
    test! { f32 }
    test! { f64 }
}