1use ahash::{HashMap, HashMapExt};
2
3use itertools::Itertools;
4use logging_timer::time;
5use statrs::function::factorial::ln_binomial;
6
7#[time("debug")]
8pub fn highest_hit_prob_per_reference(
9    total_num_k_mers: u16,
10    num_trials: usize,
11    intersection_sizes: &[u16],
12) -> Vec<f64> {
13    let intersection_size_counts = {
14        let mut counts: HashMap<u16, usize> = HashMap::new();
15        intersection_sizes
16            .iter()
17            .for_each(|item| *counts.entry(*item).or_default() += 1);
18        counts
19    };
20    let num_possible_kmer_sets = ln_binomial(
21        total_num_k_mers as u64 + num_trials as u64 - 1,
22        num_trials as u64,
23    );
24    let highest_hit_probs = if intersection_size_counts
25        .iter()
26        .any(|(&i, _)| i == total_num_k_mers)
27    {
28        intersection_size_counts
29            .iter()
30            .map(|(&n_intersections, _)| {
31                (
32                    n_intersections,
33                    only_last_pmf(
34                        total_num_k_mers as u64,
35                        num_trials as u64,
36                        n_intersections as u64,
37                        num_possible_kmer_sets,
38                    ),
39                )
40            })
41            .collect::<HashMap<u16, f64>>()
42    } else {
43        let pmfs: Vec<(u16, Vec<f64>)> = iterative_pmfs_ln(
44            total_num_k_mers as u64,
45            num_trials as u64,
46            &intersection_size_counts,
47            num_possible_kmer_sets,
48        );
49        let cmfs: Vec<Vec<f64>> = pmfs
50            .iter()
51            .map(|(_, v)| {
52                v.iter()
53                    .scan(0.0, |sum, &pmf| {
54                        if pmf != f64::NEG_INFINITY {
55                            *sum += pmf.exp();
56                        }
57                        Some(sum.ln())
58                    })
59                    .collect_vec()
60            })
61            .collect_vec();
62        let cmf_prod_components = (0..=num_trials)
63            .map(|i| {
64                intersection_size_counts
65                    .iter()
66                    .zip_eq(cmfs.iter())
67                    .map(|((_, &count), cmf)| {
68                        let x = unsafe { *cmf.get_unchecked(i) };
69                        (count as f64) * x
70                    })
71                    .sum::<f64>()
72            })
73            .collect_vec();
74        pmfs.into_iter()
75            .zip_eq(cmfs.into_iter())
76            .map(|((i, pmf), cmf)| {
77                (
78                    i,
79                    itertools::izip!(pmf.into_iter(), cmf.into_iter(), cmf_prod_components.iter())
80                        .map(|(p, c, &prod_components)| {
81                            if c == f64::NEG_INFINITY || prod_components == f64::NEG_INFINITY {
82                                0.0
83                            } else {
84                                (p + prod_components - c).exp()
85                            }
86                        })
87                        .sum::<f64>(),
88                )
89            })
90            .collect::<HashMap<u16, f64>>()
91    };
92    let highest_hit_probs = intersection_sizes
93        .iter()
94        .map(|&n_intersections| highest_hit_probs[&n_intersections])
95        .collect_vec();
96
97    let probs_sum: f64 = highest_hit_probs.iter().sum();
98    assert!(probs_sum > 0.0);
99    highest_hit_probs
100        .into_iter()
101        .map(|v| v / probs_sum)
102        .collect_vec()
103}
104
105fn only_last_pmf(
106    total_num_k_mers: u64,
107    num_trials: u64,
108    num_intersections: u64,
109    num_possible_kmer_sets: f64,
110) -> f64 {
111    if num_intersections == total_num_k_mers {
112        return 1.0;
113    }
114    if num_intersections == 0 {
115        return 0.0;
116    }
117    let num_possible_matches = ln_binomial(num_intersections + num_trials - 1, num_trials);
118    (num_possible_matches - num_possible_kmer_sets).exp()
119}
120
121fn iterative_pmfs_ln(
122    total_num_k_mers: u64,
123    num_trials: u64,
124    intersection_sizes: &HashMap<u16, usize>,
125    num_possible_kmer_sets: f64,
126) -> Vec<(u16, Vec<f64>)> {
127    intersection_sizes
128        .iter()
129        .map(|(&num_intersections, _)| {
130            if num_intersections as u64 == total_num_k_mers {
131                let mut res = vec![f64::NEG_INFINITY; num_trials as usize + 1];
132                res[num_trials as usize] = 0.0;
133                (num_intersections, res)
134            } else if num_intersections == 0 {
135                let mut res = vec![f64::NEG_INFINITY; num_trials as usize + 1];
136                res[0] = 0.0;
137                (num_intersections, res)
138            } else {
139                let num_possible_matches = (1..=num_trials).scan(0.0, |sum, i| {
140                    *sum += ((num_intersections as u64 + i - 1) as f64 / i as f64).ln();
141                    Some(*sum)
142                });
143                let impossible_init = ln_binomial(
144                    total_num_k_mers - num_intersections as u64 + num_trials - 1,
145                    num_trials,
146                );
147                let num_impossible_matches = (1..num_trials)
148                    .scan(impossible_init, |sum, i| {
149                        *sum -= ((total_num_k_mers - num_intersections as u64 + num_trials - i)
150                            as f64
151                            / (num_trials - i + 1) as f64)
152                            .ln();
153                        Some(*sum)
154                    })
155                    .chain([0.0]);
156                (
157                    num_intersections,
158                    [impossible_init - num_possible_kmer_sets]
159                        .into_iter()
160                        .chain(
161                            num_possible_matches
162                                .zip_eq(num_impossible_matches)
163                                .map(|(p, i)| p + i - num_possible_kmer_sets),
164                        )
165                        .collect_vec(),
166                )
167            }
168        })
169        .collect()
170}
171
#[cfg(test)]
mod tests {
    use ahash::HashMap;
    use itertools::Itertools;
    use statrs::{assert_almost_eq, function::factorial::ln_binomial};

    use super::{highest_hit_prob_per_reference, iterative_pmfs_ln};

    /// Closed-form (non-iterative) oracle for `iterative_pmfs_ln`: probability of
    /// exactly `i` hits out of `num_trials` for an intersection of size
    /// `num_intersections`, returned in linear (not log) space.
    fn pmf(
        total_num_k_mers: u64,
        i: u64,
        num_trials: u64,
        num_intersections: u64,
        num_possible_kmer_sets: f64,
    ) -> f64 {
        if num_intersections == total_num_k_mers {
            if i == num_trials {
                return 1.0;
            }
            return 0.0;
        }
        if num_intersections == 0 {
            if i == 0 {
                return 1.0;
            }
            return 0.0;
        }
        let num_possible_matches = ln_binomial(num_intersections + i - 1, i);
        let num_impossible_matches = ln_binomial(
            (total_num_k_mers - num_intersections) + (num_trials - i) - 1,
            num_trials - i,
        );
        (num_possible_matches + num_impossible_matches - num_possible_kmer_sets).exp()
    }

    /// The iterative log-PMF sums to one and matches the closed form term by term.
    #[test]
    fn test_pmf() {
        let num_possible_kmer_sets = ln_binomial(200 + 32 - 1, 32);
        let p = iterative_pmfs_ln(
            200,
            32,
            &HashMap::from_iter([(50, 4)]),
            num_possible_kmer_sets,
        );
        let p2 = (0..=32)
            .map(|i| pmf(200, i, 32, 50, num_possible_kmer_sets))
            .collect_vec();
        assert_almost_eq!(p[0].1.iter().map(|p| p.exp()).sum::<f64>(), 1.0, 1e-7);
        assert_almost_eq!(p2.iter().sum::<f64>(), 1.0, 1e-7);
        // `zip_eq` also checks that both PMFs cover the same number of outcomes.
        p[0].1
            .iter()
            .zip_eq(p2)
            .for_each(|(&a, b)| assert_almost_eq!(a.exp(), b, 1e-7));
    }

    /// Probabilities are normalised and never decrease as the intersection grows.
    #[test]
    fn test_hit_prob() {
        let probs = highest_hit_prob_per_reference(400, 200, &(0..=400).collect_vec());
        assert_almost_eq!(probs.iter().sum::<f64>(), 1.0, 1e-7);
        assert!(probs.windows(2).all(|w| w[0] <= w[1]));
    }
}
236}