1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
use crate::utils::log_eta;
use crate::Control;
use ndarray::{s, stack, Array1, Array2, Axis};
/// A model that scores candidate split points of the segment `start..stop`.
///
/// Implementors supply [`Classifier::predict`]. The provided methods turn those predictions
/// into log-likelihood scores by passing prior-corrected probabilities through `log_eta`.
pub trait Classifier {
    /// Returns `n` for this classifier.
    ///
    /// NOTE(review): presumably the total number of observations — confirm with implementors.
    fn n(&self) -> usize;

    /// Returns one value per observation for the segment `start..stop` split at `split`.
    ///
    /// The provided methods treat each value `x` as a probability: `x` scores an observation
    /// as belonging to the right part `split..stop`, `1 - x` as belonging to the left part
    /// `start..split`. NOTE(review): presumably `x` is the predicted probability of being
    /// right of `split` — confirm. `full_likelihood` requires the result to have length
    /// `stop - start`.
    fn predict(&self, start: usize, stop: usize, split: usize) -> Array1<f64>;

    /// Log-likelihood of splitting `start..stop` at `split`, given `predictions`.
    ///
    /// With `n = stop - start` and `k = split - start`:
    /// - each of the first `k` predictions contributes `log_eta((1 - x) * (n - 1) / (k - 1))`;
    /// - each remaining prediction contributes `log_eta(x * (n - 1) / (n - k - 1))`.
    ///
    /// Returns `0.0` when either side holds at most one observation, because the
    /// correction factors would divide by zero.
    ///
    /// # Panics
    ///
    /// Panics unless `start <= split <= stop`. Also panics if `predictions` has fewer
    /// than `split - start` entries.
    fn single_likelihood(
        &self,
        predictions: &Array1<f64>,
        start: usize,
        stop: usize,
        split: usize,
    ) -> f64 {
        if segment_too_small(start, split, stop) {
            return 0.;
        }
        let n_left = split - start;
        let n_right = stop - split;
        // Numerator shared by both corrections: the segment size minus one. It looks like
        // a leave-one-out count (each observation excluded from its own side); verify
        // against the model's derivation.
        let others = (stop - start - 1) as f64;
        let left_correction = others / ((n_left - 1) as f64);
        let right_correction = others / ((n_right - 1) as f64);
        let (left, right) = predictions.view().split_at(Axis(0), n_left);
        left.mapv(|x| log_eta((1. - x) * left_correction)).sum()
            + right.mapv(|x| log_eta(x * right_correction)).sum()
    }

    /// Per-observation log-likelihoods under both side assignments.
    ///
    /// Returns a `2 x (stop - start)` array:
    /// - row 0 scores every observation as if it were on the left, `log_eta((1 - x) * prior)`;
    /// - row 1 scores every observation as if it were on the right, `log_eta(x * prior)`.
    ///
    /// The prior is `(n - 1)` divided by the size of the assumed side. That size is reduced
    /// by one when the observation actually lies on that side, where
    /// `n = stop - start` and `k = split - start`. Summing row 0 over the first `k` columns
    /// plus row 1 over the rest reproduces [`Classifier::single_likelihood`].
    ///
    /// Returns all zeros when either side holds at most one observation.
    ///
    /// # Panics
    ///
    /// Panics unless `start <= split <= stop`. In the non-degenerate case it also panics
    /// unless `predictions.len() == stop - start`.
    fn full_likelihood(
        &self,
        predictions: &Array1<f64>,
        start: usize,
        stop: usize,
        split: usize,
    ) -> Array2<f64> {
        if segment_too_small(start, split, stop) {
            return Array2::zeros((2, stop - start));
        }
        assert_eq!(
            predictions.len(),
            stop - start,
            "predictions must hold exactly one entry per observation in start..stop"
        );
        let n_left = split - start;
        let n_right = stop - split;
        let others = (stop - start - 1) as f64;
        // prior_<assumed side><actual side>: 0 = left, 1 = right. The denominator drops by
        // one when the observation sits on the side it is being scored against.
        let prior_00 = others / ((n_left - 1) as f64);
        let prior_01 = others / (n_left as f64);
        let prior_10 = others / (n_right as f64);
        let prior_11 = others / ((n_right - 1) as f64);

        // Two copies of the predictions, one row per hypothesised side.
        let mut likelihoods = stack(Axis(0), &[predictions.view(), predictions.view()])
            .expect("stacking two views of the same array cannot fail");
        likelihoods
            .slice_mut(s![0, ..n_left])
            .mapv_inplace(|x| log_eta((1. - x) * prior_00));
        likelihoods
            .slice_mut(s![0, n_left..])
            .mapv_inplace(|x| log_eta((1. - x) * prior_01));
        likelihoods
            .slice_mut(s![1, ..n_left])
            .mapv_inplace(|x| log_eta(x * prior_10));
        likelihoods
            .slice_mut(s![1, n_left..])
            .mapv_inplace(|x| log_eta(x * prior_11));
        likelihoods
    }

    /// Returns a reference to this classifier's [`Control`] settings.
    fn control(&self) -> &Control;
}

/// Returns `true` when `start..split` or `split..stop` holds at most one observation.
///
/// In that case the prior corrections used by [`Classifier`] would divide by zero.
///
/// # Panics
///
/// Panics unless `start <= split <= stop`, so the `usize` subtractions cannot underflow.
fn segment_too_small(start: usize, split: usize, stop: usize) -> bool {
    assert!(
        start <= split && split <= stop,
        "expected start <= split <= stop, got start={}, split={}, stop={}",
        start,
        split,
        stop
    );
    split - start <= 1 || stop - split <= 1
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::testing::TrivialClassifier;
    use assert_approx_eq::*;

    #[test]
    fn test_single_likelihood() {
        let control = Control::default();
        let classifier = TrivialClassifier {
            n: 10,
            control: &control,
        };
        let predictions = classifier.predict(0, 10, 5);
        assert_approx_eq!(
            classifier.single_likelihood(&predictions, 0, 10, 5),
            0.809552182
        );
    }

    #[test]
    fn test_full_likelihood() {
        let control = Control::default();
        let classifier = TrivialClassifier {
            n: 10,
            control: &control,
        };
        let predictions = classifier.predict(0, 10, 5);
        let actual = classifier.full_likelihood(&predictions, 0, 10, 5);

        let mut expected = Array2::<f64>::zeros((2, 10));
        expected[[0, 0]] = 0.8095521826214339;
        expected[[1, 0]] = -6.0;

        // Compare with a tolerance: exact equality on computed floats is brittle.
        assert_eq!(actual.shape(), expected.shape());
        for (got, want) in actual.iter().zip(expected.iter()) {
            assert_approx_eq!(*got, *want);
        }
    }

    #[test]
    fn test_single_likelihood_matches_full_likelihood() {
        let control = Control::default();
        let classifier = TrivialClassifier {
            n: 10,
            control: &control,
        };
        let split = 5;
        let predictions = classifier.predict(0, 10, split);
        let single = classifier.single_likelihood(&predictions, 0, 10, split);
        let full = classifier.full_likelihood(&predictions, 0, 10, split);

        // Left observations scored as "left" (row 0) plus right ones scored as "right" (row 1).
        let left: f64 = full.row(0).iter().take(split).sum();
        let right: f64 = full.row(1).iter().skip(split).sum();
        assert_approx_eq!(single, left + right);
    }

    #[test]
    fn test_degenerate_splits_score_zero() {
        let control = Control::default();
        let classifier = TrivialClassifier {
            n: 10,
            control: &control,
        };
        let predictions = classifier.predict(0, 10, 5);
        // Either side holding at most one observation triggers the zero early-return.
        for &split in &[0, 1, 9, 10] {
            assert_approx_eq!(classifier.single_likelihood(&predictions, 0, 10, split), 0.0);
            assert_eq!(
                classifier.full_likelihood(&predictions, 0, 10, split),
                Array2::<f64>::zeros((2, 10))
            );
        }
    }
}