1use std::collections::HashMap;
15use std::marker::PhantomData;
16
17use crate::annotations::{AnnotationId, Disease, GeneId, OmimDiseaseId, OrphaDiseaseId};
18use crate::HpoTerm;
19
20pub mod hypergeom;
21mod linkage;
22pub use linkage::cluster;
23pub use linkage::Linkage;
24
/// Result of an enrichment calculation for a single annotation.
///
/// `T` is the identifier type of the annotation; constructors exist for
/// [`GeneId`], [`OmimDiseaseId`] and any type implementing [`AnnotationId`].
#[derive(Debug)]
pub struct Enrichment<T> {
    /// The annotation (e.g. a gene or disease id) this result belongs to.
    annotation: T,
    /// The p-value stored for this annotation.
    pvalue: f64,
    /// The count stored for this annotation.
    /// NOTE(review): presumably the number of sample terms linked to the
    /// annotation — confirm against the callers that build this struct.
    count: u64,
    /// The enrichment value stored for this annotation.
    // The field intentionally shares the struct's name (it backs the public
    // `enrichment()` accessor), which `clippy::struct_field_names` would flag.
    #[allow(clippy::struct_field_names)]
    enrichment: f64,
}
38
39impl<T> Enrichment<T> {
40 pub fn pvalue(&self) -> f64 {
45 self.pvalue
46 }
47
48 pub fn enrichment(&self) -> f64 {
50 self.enrichment
51 }
52
53 pub fn id(&self) -> &T {
55 &self.annotation
56 }
57
58 pub fn count(&self) -> u64 {
60 self.count
61 }
62}
63
64impl Enrichment<GeneId> {
65 pub fn gene(gene: GeneId, pvalue: f64, count: u64, enrichment: f64) -> Self {
67 Self {
68 annotation: gene,
69 pvalue,
70 count,
71 enrichment,
72 }
73 }
74}
75
76impl Enrichment<OmimDiseaseId> {
77 pub fn disease(disease: OmimDiseaseId, pvalue: f64, count: u64, enrichment: f64) -> Self {
79 Self {
80 annotation: disease,
81 pvalue,
82 count,
83 enrichment,
84 }
85 }
86}
87
88impl<T: AnnotationId> Enrichment<T> {
89 pub fn annotation(annotation: T, pvalue: f64, count: u64, enrichment: f64) -> Self {
91 Self {
92 annotation,
93 pvalue,
94 count,
95 enrichment,
96 }
97 }
98}
99
/// Per-annotation occurrence counts collected from a set of terms.
///
/// `T` only tags which annotation kind the raw `u32` keys belong to.
struct SampleSet<T> {
    /// Raw annotation id -> number of times it was seen.
    counts: HashMap<u32, u64>,
    /// Number of terms that were processed.
    size: u64,
    /// Zero-sized marker for the annotation id type.
    phantom: PhantomData<T>,
}
107
108fn calculate_counts<
109 'a,
110 U: FnMut(HpoTerm<'a>) -> IT,
111 I: IntoIterator<Item = HpoTerm<'a>>,
112 IT: IntoIterator<Item = u32>,
113>(
114 terms: I,
115 mut iter: U,
116) -> (u64, HashMap<u32, u64>) {
117 let mut size = 0u64;
118 let mut counts: HashMap<u32, u64> = HashMap::new();
119 for term in terms {
120 size += 1;
121 for id in iter(term) {
122 counts
123 .entry(id)
124 .and_modify(|count| *count += 1)
125 .or_insert(1);
126 }
127 }
128 (size, counts)
129}
130
131impl<'a> SampleSet<GeneId> {
132 pub fn gene<I: IntoIterator<Item = HpoTerm<'a>>>(terms: I) -> Self {
134 let term2geneid = |term: HpoTerm<'a>| term.genes().map(|d| d.id().as_u32());
135
136 let (size, counts) = calculate_counts(terms, term2geneid);
137 Self {
138 size,
139 counts,
140 phantom: PhantomData,
141 }
142 }
143}
144
145impl<'a> SampleSet<OmimDiseaseId> {
146 pub fn omim_disease<I: IntoIterator<Item = HpoTerm<'a>>>(terms: I) -> Self {
148 let term2omimid = |term: HpoTerm<'a>| term.omim_diseases().map(|d| d.id().as_u32());
149 let (size, counts) = calculate_counts(terms, term2omimid);
150 Self {
151 size,
152 counts,
153 phantom: PhantomData,
154 }
155 }
156}
157
158impl<'a> SampleSet<OrphaDiseaseId> {
159 pub fn orpha_disease<I: IntoIterator<Item = HpoTerm<'a>>>(terms: I) -> Self {
161 let term2omimid = |term: HpoTerm<'a>| term.orpha_diseases().map(|d| d.id().as_u32());
162 let (size, counts) = calculate_counts(terms, term2omimid);
163 Self {
164 size,
165 counts,
166 phantom: PhantomData,
167 }
168 }
169}
170
171impl<T: AnnotationId> SampleSet<T> {
172 fn len(&self) -> u64 {
174 self.size
175 }
176
177 #[allow(dead_code)]
181 fn is_empty(&self) -> bool {
182 self.size == 0
183 }
184 fn get(&self, key: &T) -> Option<&u64> {
190 self.counts.get(&key.as_u32())
191 }
192
193 #[allow(dead_code)]
200 fn frequency(&self, key: &T) -> Option<f64> {
201 self.counts
202 .get(&key.as_u32())
203 .map(|count| f64_from_u64(*count) / f64_from_u64(self.size))
204 }
205
206 #[allow(dead_code)]
208 fn frequencies(&'_ self) -> Frequencies<'_, T> {
209 Frequencies::new(self.counts.iter(), self.size, self.phantom)
210 }
211}
212
213impl<'a, T: AnnotationId> IntoIterator for &'a SampleSet<T> {
214 type Item = (T, u64);
215 type IntoIter = Counts<'a, T>;
216 fn into_iter(self) -> Self::IntoIter {
217 Counts::new(self.counts.iter(), self.phantom)
218 }
219}
220
/// Iterator adapter turning raw `(u32, count)` map entries into
/// `(K, count / total)` pairs.
struct Frequencies<'a, K> {
    /// Denominator applied to every count.
    total: u64,
    /// Entries of the counts map being walked.
    inner: ::std::collections::hash_map::Iter<'a, u32, u64>,
    /// Zero-sized marker for the produced id type `K`.
    phantom: PhantomData<K>,
}
227
228impl<'a, K> Frequencies<'a, K> {
229 pub fn new(
230 inner: std::collections::hash_map::Iter<'a, u32, u64>,
231 total: u64,
232 phantom: PhantomData<K>,
233 ) -> Self {
234 Self {
235 inner,
236 total,
237 phantom,
238 }
239 }
240}
241
242impl<K: AnnotationId> Iterator for Frequencies<'_, K> {
243 type Item = (K, f64);
244 fn next(&mut self) -> Option<Self::Item> {
245 self.inner
246 .next()
247 .map(|(k, v)| (K::from(*k), f64_from_u64(*v) / f64_from_u64(self.total)))
248 }
249}
250
/// Iterator adapter turning raw `(u32, count)` map entries into `(K, count)`.
struct Counts<'a, K> {
    /// Zero-sized marker for the produced id type `K`.
    phantom: PhantomData<K>,
    /// Entries of the counts map being walked.
    inner: ::std::collections::hash_map::Iter<'a, u32, u64>,
}
256
257impl<'a, K> Counts<'a, K> {
258 pub fn new(
259 inner: std::collections::hash_map::Iter<'a, u32, u64>,
260 phantom: PhantomData<K>,
261 ) -> Self {
262 Self { inner, phantom }
263 }
264}
265
266impl<K: AnnotationId> Iterator for Counts<'_, K> {
267 type Item = (K, u64);
268 fn next(&mut self) -> Option<Self::Item> {
269 self.inner.next().map(|(k, v)| (K::from(*k), *v))
270 }
271}
272
/// Converts a `u64` into an `f64` without losing precision.
///
/// Every integer up to 2^53 (the width of the `f64` significand) has an exact
/// `f64` representation, so that whole range is accepted. The value is
/// assembled from its two `u32` halves, each of which converts losslessly via
/// `f64::from`, so no lossy `as` cast is needed.
///
/// # Panics
///
/// Panics if `n` is larger than 2^53 and therefore has no exact `f64` value.
fn f64_from_u64(n: u64) -> f64 {
    const MAX_EXACT: u64 = 1 << f64::MANTISSA_DIGITS;
    const TWO_POW_32: f64 = 4_294_967_296.0;

    assert!(n <= MAX_EXACT, "cannot safely create f64 from large u64");

    let high = u32::try_from(n >> 32).expect("upper half of a u64 fits into u32");
    let low = u32::try_from(n & u64::from(u32::MAX)).expect("lower half of a u64 fits into u32");
    // `high * 2^32` is exact (scaling by a power of two), and because
    // `n <= 2^53` the final sum is exactly representable as well.
    f64::from(high) * TWO_POW_32 + f64::from(low)
}
282
283fn f64_from_usize(n: usize) -> f64 {
287 let intermediate: u32 = n
288 .try_into()
289 .expect("cannot safely create f64 from large u64");
290 intermediate.into()
291}
292
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn iterate_frequencies() {
        let mut map = HashMap::new();
        map.insert(12u32, 12u64);
        map.insert(21u32, 21u64);

        let iter: Frequencies<'_, OmimDiseaseId> = Frequencies::new(map.iter(), 3, PhantomData);
        // HashMap iteration order is unspecified, so sort before comparing.
        // Collecting also proves the iterator is exhausted after both entries.
        let mut result: Vec<(u32, f64)> = iter.map(|(key, freq)| (key.as_u32(), freq)).collect();
        result.sort_by_key(|(key, _)| *key);

        assert_eq!(result.len(), 2);
        assert_eq!(result[0].0, 12);
        assert!((result[0].1 - 4.0).abs() < f64::EPSILON);
        assert_eq!(result[1].0, 21);
        assert!((result[1].1 - 7.0).abs() < f64::EPSILON);
    }

    #[test]
    fn iterate_counts() {
        let mut map = HashMap::new();
        map.insert(12u32, 12u64);
        map.insert(21u32, 21u64);

        let iter: Counts<'_, OmimDiseaseId> = Counts::new(map.iter(), PhantomData);
        // HashMap iteration order is unspecified, so sort before comparing.
        let mut result: Vec<(u32, u64)> = iter.map(|(key, count)| (key.as_u32(), count)).collect();
        result.sort_unstable();

        assert_eq!(result, vec![(12, 12), (21, 21)]);
    }

    #[test]
    fn sample_set_lookups() {
        let mut counts = HashMap::new();
        counts.insert(12u32, 3u64);
        let set: SampleSet<OmimDiseaseId> = SampleSet {
            size: 4,
            counts,
            phantom: PhantomData,
        };

        assert_eq!(set.len(), 4);
        assert!(!set.is_empty());
        assert_eq!(set.get(&OmimDiseaseId::from(12)), Some(&3));
        assert_eq!(set.get(&OmimDiseaseId::from(21)), None);

        let freq = set
            .frequency(&OmimDiseaseId::from(12))
            .expect("id 12 was inserted above");
        assert!((freq - 0.75).abs() < f64::EPSILON);
        assert!(set.frequency(&OmimDiseaseId::from(21)).is_none());
    }

    #[test]
    fn float_conversion_of_small_values() {
        assert!((f64_from_u64(0) - 0.0).abs() < f64::EPSILON);
        assert!((f64_from_u64(42) - 42.0).abs() < f64::EPSILON);
        assert!((f64_from_usize(7) - 7.0).abs() < f64::EPSILON);
    }
}