classifier_measures/lib.rs
1 #![warn(missing_docs)]
2
3 /*!
Measure a classifier's performance using [Receiver Operating
Characteristic](https://en.wikipedia.org/wiki/Receiver_operating_characteristic)
(ROC) and [Precision-Recall](https://en.wikipedia.org/wiki/Precision_and_recall)
(PR) curves.

Both the curves themselves and the trapezoidal areas under them can be computed.
10 */
11
12extern crate num_traits;
13
14use num_traits::{Float, NumCast};
15
16mod test;
17
18/// Integration using the trapezoidal rule.
19fn trapezoidal<F: Float>(x: &[F], y: &[F]) -> F {
20 let mut prev_x = x[0];
21 let mut prev_y = y[0];
22 let mut integral = F::zero();
23
24 for (&x, &y) in x.iter().skip(1).zip(y.iter().skip(1)) {
25 integral = integral + (x - prev_x) * (prev_y + y) / NumCast::from(2.0).unwrap();
26 prev_x = x;
27 prev_y = y;
28 }
29 integral
30}
31
32/// Checks if all data-points are finite
33fn check_data<F: Float>(v: &[(bool, F)]) -> bool {
34 v.iter().any(|x| x.0) && v.iter().any(|x| !x.0) && v.iter().all(|x| x.1.is_finite())
35}
36
37/// Uses a provided closure to convert free-form data into a standard format
38fn convert<T, X, F, I>(data: I, convert_fn: F) -> Vec<(bool, X)> where
39 I: IntoIterator<Item=T>,
40 F: Fn(T) -> (bool, X),
41 X: Float {
42 let data_it = data.into_iter();
43 let mut v = Vec::with_capacity(data_it.size_hint().0);
44 for i in data_it {
45 v.push(convert_fn(i));
46 }
47 v
48}
49
50/// Computes a ROC curve of a given classifier sorting `pairs` in-place.
51///
52/// Returns `None` if one of the classes is not present or any values are non-finite.
53/// Otherwise, returns `Some((v_x, v_y))` where `v_x` are the x-coordinates and `v_y` are the
54/// y-coordinates of the ROC curve.
55pub fn roc_mut<F: Float>(pairs: &mut [(bool, F)]) -> Option<(Vec<F>, Vec<F>)> {
56 if !check_data(pairs) {
57 return None;
58 }
59 pairs.sort_unstable_by(&|x: &(_, F), y: &(_, F)|
60 match y.1.partial_cmp(&x.1) {
61 Some(ord) => ord,
62 None => unreachable!(),
63 });
64
65 let mut s0 = F::nan();
66 let (mut tp, mut fp) = (F::zero(), F::zero());
67 let (mut tps, mut fps) = (vec![], vec![]);
68 for &(t, s) in pairs.iter() {
69 if s != s0 {
70 tps.push(tp);
71 fps.push(fp);
72 s0 = s;
73 }
74 match t {
75 false => fp = fp + F::one(),
76 true => tp = tp + F::one(),
77 }
78 }
79 tps.push(tp);
80 fps.push(fp);
81
82 // normalize
83 let (tp_max, fp_max) = (tps[tps.len() - 1], fps[fps.len() - 1]);
84 for tp in &mut tps {
85 *tp = *tp / tp_max;
86 }
87 for fp in &mut fps {
88 *fp = *fp / fp_max;
89 }
90 Some((fps, tps))
91}
92
93/// Computes a ROC curve of a given classifier.
94///
95/// `data` is a free-form `IntoIterator` object and `convert_fn` is a closure that converts each
96/// data-point into a pair `(ground_truth, prediction).`
97///
98/// Returns `None` if one of the classes is not present or any values are non-finite.
99/// Otherwise, returns `Some((v_x, v_y))` where `v_x` are the x-coordinates and `v_y` are the
100/// y-coordinates of the ROC curve.
101pub fn roc<T, X, F, I>(data: I, convert_fn: F) -> Option<(Vec<X>, Vec<X>)> where
102 I: IntoIterator<Item=T>,
103 F: Fn(T) -> (bool, X),
104 X: Float {
105 roc_mut(&mut convert(data, convert_fn))
106}
107
108/// Computes a PR curve of a given classifier.
109///
110/// `data` is a free-form `IntoIterator` object and `convert_fn` is a closure that converts each
111/// data-point into a pair `(ground_truth, prediction).`
112///
113/// Returns `None` if one of the classes is not present or any values are non-finite.
114/// Otherwise, returns `Some((v_x, v_y))` where `v_x` are the x-coordinates and `v_y` are the
115/// y-coordinates of the PR curve.
116pub fn pr<T, X, F, I>(data: I, convert_fn: F) -> Option<(Vec<X>, Vec<X>)> where
117 I: IntoIterator<Item=T>,
118 F: Fn(T) -> (bool, X),
119 X: Float {
120 pr_mut(&mut convert(data, convert_fn))
121}
122
123/// Computes a PR curve of a given classifier sorting `pairs` in-place.
124///
125/// Returns `None` if one of the classes is not present or any values are non-finite.
126/// Otherwise, returns `Some((v_x, v_y))` where `v_x` are the x-coordinates and `v_y` are the
127/// y-coordinates of the PR curve.
128pub fn pr_mut<F: Float>(pairs: &mut [(bool, F)]) -> Option<(Vec<F>, Vec<F>)> {
129 if !check_data(pairs) {
130 return None;
131 }
132 pairs.sort_unstable_by(&|x: &(_, F), y: &(_, F)|
133 match y.1.partial_cmp(&x.1) {
134 Some(ord) => ord,
135 None => unreachable!(),
136 });
137
138 let mut x0 = F::nan();
139 let (mut tp, mut p, mut fp) = (F::zero(), F::zero(), F::zero());
140 let (mut recall, mut precision) = (vec![], vec![]);
141
142 // number of labels
143 let ln = pairs.iter().fold(F::zero(), |a,b| a + if b.0 { F::one() } else { F::zero() });
144
145 for &(l, x) in pairs.iter() {
146 if x != x0 {
147 recall.push(tp / ln);
148 precision.push(if p == F::zero() { F::one() } else { tp / (tp + fp) });
149 x0 = x;
150 }
151 p = p + F::one();
152 if l { tp = tp + F::one(); }
153 else { fp = fp + F::one(); }
154 }
155 recall.push(tp / ln);
156 precision.push(tp / p);
157
158 Some((precision, recall))
159}
160
161/// Computes the area under a PR curve of a given classifier.
162///
163/// `data` is a free-form `IntoIterator` object and `convert_fn` is a closure that converts each
164/// data-point into a pair `(ground_truth, prediction).`
165///
166/// Returns `None` if one of the classes is not present or any values are non-finite.
167/// Otherwise, returns `Some(area_under_curve)`.
168pub fn pr_auc<T, X, F, I>(data: I, convert_fn: F) -> Option<X> where
169 I: IntoIterator<Item=T>,
170 F: Fn(T) -> (bool, X),
171 X: Float {
172 pr_auc_mut(&mut convert(data, convert_fn))
173}
174
175/// Computes the area under a PR curve of a given classifier sorting `pairs` in-place.
176///
177/// Returns `None` if one of the classes is not present or any values are non-finite.
178/// Otherwise, returns `Some(area_under_curve)`.
179pub fn pr_auc_mut<F: Float>(pairs: &mut [(bool, F)]) -> Option<F> {
180 pr_mut(pairs).map(|curve| {
181 trapezoidal(&curve.1, &curve.0)
182 })
183}
184
185/// Computes the area under a ROC curve of a given classifier.
186///
187/// `data` is a free-form `IntoIterator` object and `convert_fn` is a closure that converts each
188/// data-point into a pair `(ground_truth, prediction).`
189///
190/// Returns `None` if one of the classes is not present or any values are non-finite.
191/// Otherwise, returns `Some(area_under_curve)`.
192pub fn roc_auc<T, X, F, I>(data: I, convert_fn: F) -> Option<X> where
193 I: IntoIterator<Item=T>,
194 F: Fn(T) -> (bool, X),
195 X: Float {
196 roc_auc_mut(&mut convert(data, convert_fn))
197}
198
199/// Computes the area under a ROC curve of a given classifier sorting `pairs` in-place.
200///
201/// Returns `None` if one of the classes is not present or any values are non-finite.
202/// Otherwise, returns `Some(area_under_curve)`.
203pub fn roc_auc_mut<F: Float>(pairs: &mut [(bool, F)]) -> Option<F> {
204 roc_mut(pairs).map(|curve| trapezoidal(&curve.0, &curve.1))
205}