1use ndarray::Array2;
9
/// Median of `v`, defined as R's `median()` defines it: the middle value for odd
/// lengths, the mean of the two middle values for even lengths.
///
/// When a number is returned, `v` has been sorted ascending in place as a side
/// effect (callers may rely on that ordering).
///
/// Returns `NaN` (instead of panicking) when `v` is empty or contains a `NaN`,
/// mirroring R, which yields `NA` in both cases.
fn median(v: &mut [f64]) -> f64 {
    if v.is_empty() || v.iter().any(|x| x.is_nan()) {
        return f64::NAN;
    }
    // NaN was rejected above, so every pair is comparable.
    v.sort_by(|a, b| a.partial_cmp(b).expect("NaN values were rejected above"));
    let mid = v.len() / 2;
    if v.len() % 2 == 1 {
        v[mid]
    } else {
        (v[mid - 1] + v[mid]) / 2.0
    }
}
19
/// Cumulative distribution function of the exponential distribution with the
/// given `rate`, evaluated at `q` (lower tail), i.e. R's `pexp(q, rate)`.
///
/// Returns `0.0` for `q < 0`; `NaN` propagates through. Uses `exp_m1` so that
/// small `rate * q` keeps full precision (`1 - exp(-x)` cancels to 0 when `x` is
/// tiny); R's own C implementation computes the lower tail the same way.
fn pexp(q: f64, rate: f64) -> f64 {
    if q < 0.0 {
        0.0
    } else {
        -(-rate * q).exp_m1()
    }
}
28
29pub fn propexpr(x: &Array2<f64>, neg_x: &Array2<f64>) -> Vec<f64> {
35 let narrays = x.ncols();
36 assert_eq!(
37 neg_x.ncols(),
38 narrays,
39 "x and neg_x must have equal columns"
40 );
41
42 let mut out = vec![0.0_f64; narrays];
43 for (i, o) in out.iter_mut().enumerate() {
44 let mut b: Vec<f64> = neg_x
45 .column(i)
46 .iter()
47 .copied()
48 .filter(|v| !v.is_nan())
49 .collect();
50 let r: Vec<f64> = x
51 .column(i)
52 .iter()
53 .copied()
54 .filter(|v| !v.is_nan())
55 .collect();
56 let nb = b.len() as f64;
57 let nr = r.len() as f64;
58
59 let mu = b.iter().sum::<f64>() / nb;
60 let alpha = (r.iter().sum::<f64>() / nr - mu).max(10.0);
61 let b1 = median(&mut b);
62
63 let p1 = b.iter().map(|&bj| pexp(b1 - bj, 1.0 / alpha)).sum::<f64>() / nb;
64 let pb = (b.iter().filter(|&&v| v < b1).count() as f64
65 + b.iter().filter(|&&v| v == b1).count() as f64 / 2.0)
66 / nb;
67 let p = (r.iter().filter(|&&v| v < b1).count() as f64
68 + r.iter().filter(|&&v| v == b1).count() as f64 / 2.0)
69 / nr;
70
71 *o = ((pb - p) / (pb - p1)).clamp(0.0, 1.0);
72 }
73 out
74}
75
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// Element-wise comparison with a combined absolute/relative tolerance `tol`.
    fn close(a: &[f64], b: &[f64], tol: f64) -> bool {
        a.len() == b.len()
            && a.iter()
                .zip(b)
                .all(|(&x, &y)| (x - y).abs() <= tol + tol * y.abs())
    }

    #[test]
    fn two_arrays_match_r() {
        let neg = array![
            [4.0, 3.0],
            [5.0, 4.0],
            [5.0, 5.0],
            [6.0, 5.0],
            [6.0, 6.0],
            [7.0, 7.0],
        ];
        let reg = array![
            [5.0, 4.0],
            [5.0, 5.0],
            [6.0, 5.0],
            [7.0, 6.0],
            [8.0, 8.0],
            [9.0, 9.0],
            [10.0, 11.0],
            [12.0, 13.0],
        ];
        let want = [0.54285538831644, 0.550748101666388];
        assert!(close(&propexpr(&reg, &neg), &want, 1e-9));
    }

    #[test]
    fn nan_entries_are_ignored() {
        // Same data as `two_arrays_match_r` plus one all-NaN row in each input.
        let neg = array![
            [4.0, 3.0],
            [5.0, 4.0],
            [f64::NAN, f64::NAN],
            [5.0, 5.0],
            [6.0, 5.0],
            [6.0, 6.0],
            [7.0, 7.0],
        ];
        let reg = array![
            [5.0, 4.0],
            [5.0, 5.0],
            [6.0, 5.0],
            [7.0, 6.0],
            [f64::NAN, f64::NAN],
            [8.0, 8.0],
            [9.0, 9.0],
            [10.0, 11.0],
            [12.0, 13.0],
        ];
        let want = [0.54285538831644, 0.550748101666388];
        assert!(close(&propexpr(&reg, &neg), &want, 1e-9));
    }

    #[test]
    fn result_is_clamped_to_unit_interval() {
        // Column 0: every regular value is far above background -> raw ratio > 1.
        // Column 1: every regular value is below background -> raw ratio < 0.
        let neg = array![[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]];
        let reg = array![[100.0, 0.0], [200.0, 0.5]];
        assert_eq!(propexpr(&reg, &neg), vec![1.0, 0.0]);
    }
}