criterion_stats/bivariate/
regression.rs1use float::Float;
4
5use bivariate::Data;
6
7#[derive(Clone, Copy)]
9pub struct Slope<A>(pub A)
10where
11 A: Float;
12
13impl<A> Slope<A>
14where
15 A: Float,
16{
17 pub fn fit(data: &Data<A, A>) -> Slope<A> {
22 let xs = data.0;
23 let ys = data.1;
24
25 let xy = ::dot(xs, ys);
26 let x2 = ::dot(xs, xs);
27
28 Slope(xy / x2)
29 }
30
31 pub fn r_squared(&self, data: &Data<A, A>) -> A {
35 let _0 = A::cast(0);
36 let _1 = A::cast(1);
37 let m = self.0;
38 let xs = data.0;
39 let ys = data.1;
40
41 let n = A::cast(xs.len());
42 let y_bar = ::sum(ys) / n;
43
44 let mut ss_res = _0;
45 let mut ss_tot = _0;
46
47 for (&x, &y) in data.iter() {
48 ss_res = ss_res + (y - m * x).powi(2);
49 ss_tot = ss_res + (y - y_bar).powi(2);
50 }
51
52 _1 - ss_res / ss_tot
53 }
54}
55
56#[derive(Clone, Copy)]
58pub struct StraightLine<A>
59where
60 A: Float,
61{
62 pub intercept: A,
64 pub slope: A,
66}
67
68impl<A> StraightLine<A>
69where
70 A: Float,
71{
72 #[cfg_attr(feature = "cargo-clippy", allow(clippy::similar_names))]
76 pub fn fit(data: Data<A, A>) -> StraightLine<A> {
77 let xs = data.0;
78 let ys = data.1;
79
80 let x2 = ::dot(xs, xs);
81 let xy = ::dot(xs, ys);
82
83 let n = A::cast(xs.len());
84 let x2_bar = x2 / n;
85 let x_bar = ::sum(xs) / n;
86 let xy_bar = xy / n;
87 let y_bar = ::sum(ys) / n;
88
89 let slope = {
90 let num = xy_bar - x_bar * y_bar;
91 let den = x2_bar - x_bar * x_bar;
92
93 num / den
94 };
95
96 let intercept = y_bar - slope * x_bar;
97
98 StraightLine { intercept, slope }
99 }
100
101 pub fn r_squared(&self, data: Data<A, A>) -> A {
105 let _0 = A::cast(0);
106 let _1 = A::cast(1);
107 let m = self.slope;
108 let b = self.intercept;
109 let xs = data.0;
110 let ys = data.1;
111
112 let n = A::cast(xs.len());
113 let y_bar = ::sum(ys) / n;
114
115 let mut ss_res = _0;
116 let mut ss_tot = _0;
117 for (&x, &y) in data.iter() {
118 ss_res = ss_res + (y - m * x - b).powi(2);
119 ss_tot = ss_tot + (y - y_bar).powi(2);
120 }
121
122 _1 - ss_res / ss_tot
123 }
124}
125
/// Generates a quickcheck module (named after the float type `$ty`) checking that the
/// R² of a `StraightLine` fit lies within `[0, 1]`, up to `relative_eq!` tolerance.
#[cfg(test)]
macro_rules! test {
    ($ty:ident) => {
        mod $ty {
            use quickcheck::TestResult;

            use bivariate::regression::StraightLine;
            use bivariate::Data;

            quickcheck! {
                fn r_squared(size: usize, start: usize, offset: usize) -> TestResult {
                    // Inputs for which no sample vector can be built are discarded.
                    let xs = match ::test::vec::<$ty>(size, start) {
                        Some(xs) => xs,
                        None => return TestResult::discard(),
                    };
                    let ys = ::test::vec::<$ty>(size + offset, start + offset).unwrap();
                    let data = Data::new(&xs[start..], &ys[start + offset..]);

                    let r2 = StraightLine::fit(data).r_squared(data);
                    let at_least_zero = r2 > 0. || relative_eq!(r2, 0.);
                    let at_most_one = r2 < 1. || relative_eq!(r2, 1.);

                    TestResult::from_bool(at_least_zero && at_most_one)
                }
            }
        }
    };
}
157
#[cfg(test)]
mod test {
    // Instantiate the regression property tests once per floating-point width.
    test!(f32);
    test!(f64);
}