use std::{
collections::{HashMap, HashSet},
sync::Arc,
};
use ontolius::{ontology::HierarchyWalks, Identified, TermId};
use phenotypes::Observable;
use crate::TermPair;
#[derive(Debug, Clone, PartialEq)]
pub struct TermIC {
pub present: f64,
pub excluded: f64,
}
pub struct CohortIcCalculator<O> {
hpo: Arc<O>,
module_root: TermId,
}
impl<O> CohortIcCalculator<O> {
pub fn new(hpo: Arc<O>, module_root: TermId) -> Self {
CohortIcCalculator { hpo, module_root }
}
}
#[derive(Debug, Default)]
struct TermCount {
present: u32,
excluded: u32,
}
impl<O> CohortIcCalculator<O>
where
O: HierarchyWalks,
{
pub fn compute_ic<C, M, A>(&self, cohort: C) -> HashMap<TermId, TermIC>
where
C: AsRef<[M]>,
M: AsRef<[A]>,
A: Identified + Observable,
{
let mut module_term_ids = HashSet::new();
module_term_ids.extend(self.hpo.iter_term_and_descendant_ids(&self.module_root));
let mut idx2count: HashMap<_, TermCount> = HashMap::with_capacity(module_term_ids.len());
for member in cohort.as_ref() {
for annotation in member.as_ref() {
let term_id = annotation.identifier();
if module_term_ids.contains(term_id) {
if annotation.is_present() {
for anc in self.hpo.iter_term_and_ancestor_ids(term_id) {
if module_term_ids.contains(anc) {
idx2count.entry(anc).or_default().present += 1;
}
}
} else {
for desc in self.hpo.iter_term_and_descendant_ids(term_id) {
idx2count.entry(desc).or_default().excluded += 1;
}
}
}
}
}
if idx2count.is_empty() {
return HashMap::new();
}
let pop_present_count = idx2count[&self.module_root].present as f64;
let pop_excluded_count = idx2count
.values()
.max_by_key(|&count| count.excluded)
.map(|count| count.excluded)
.expect("Idx2count should not be empty") as f64;
idx2count
.into_iter()
.map(|(term_id, count)| {
(
Clone::clone(term_id),
TermIC {
present: f64::log2(pop_present_count / count.present as f64),
excluded: f64::log2(pop_excluded_count / count.excluded as f64),
},
)
})
.collect()
}
pub fn compute_ic_mica<C, M, A>(&self, cohort: C) -> HashMap<TermPair, f64>
where
C: AsRef<[M]>,
M: AsRef<[A]>,
A: Identified + Observable,
{
let mut cohort_ig = HashSet::new();
for member in cohort.as_ref() {
for pf in member.as_ref() {
if pf.is_present() {
cohort_ig.extend(
self.hpo
.iter_term_and_ancestor_ids(pf.identifier())
.cloned(),
);
}
}
}
let ic = self.compute_ic(cohort);
let mut module = Vec::new();
let mut ic_micas = HashMap::new();
for module_root in self.hpo.iter_child_ids(&self.module_root) {
module.extend(
self.hpo
.iter_term_and_descendant_ids(module_root)
.filter(|ti| cohort_ig.contains(ti)),
);
for (i, &left) in module.iter().enumerate() {
for &right in &module[i..] {
if let Some(ic_mica) = common_ancestors(left, right, self.hpo.as_ref())
.into_iter()
.flat_map(|t| ic.get(t).map(|tic| tic.present))
.filter(|&f| f > 0.)
.reduce(f64::max)
{
let key = TermPair::from((left, right));
ic_micas
.entry(key)
.and_modify(|val: &mut f64| *val = val.max(ic_mica))
.or_insert(ic_mica);
}
}
}
module.clear();
}
ic_micas
}
}
fn common_ancestors<'a, O>(left: &'a TermId, right: &'a TermId, hpo: &'a O) -> Vec<&'a TermId>
where
O: HierarchyWalks,
{
let work: Vec<_> = hpo.iter_term_and_ancestor_ids(left).collect();
hpo.iter_term_and_ancestor_ids(right)
.filter(|x| work.contains(x))
.collect()
}
#[cfg(test)]
mod tests {
use std::{
fs::File,
io::BufReader,
sync::{Arc, OnceLock},
};
use flate2::bufread::GzDecoder;
use ontolius::{
common::hpo::PHENOTYPIC_ABNORMALITY, io::OntologyLoaderBuilder,
ontology::csr::MinimalCsrOntology, TermId,
};
use crate::{
ic::{CohortIcCalculator, TermPair},
subjects::fbn1_ectopia_lentis_subjects,
};
static HPO: OnceLock<Arc<MinimalCsrOntology>> = OnceLock::new();
fn load_hpo() -> Arc<MinimalCsrOntology> {
let path = "resources/hp.v2024-08-13.json.gz";
Arc::new(
OntologyLoaderBuilder::new()
.obographs_parser()
.build()
.load_from_read(GzDecoder::new(BufReader::new(File::open(path).unwrap())))
.expect("Should be loadable"),
)
}
#[test]
fn test_cohort_ic_calculator() {
let hpo = Arc::clone(HPO.get_or_init(load_hpo));
let fbn1 = fbn1_ectopia_lentis_subjects();
let pa = PHENOTYPIC_ABNORMALITY.clone();
let calculator = CohortIcCalculator::new(hpo, pa);
let ic_container = calculator.compute_ic(&fbn1);
assert_eq!(ic_container.len(), 178);
assert!(!ic_container
.values()
.any(|term_ic| term_ic.present.is_nan() || term_ic.excluded.is_nan()));
let pa_ic = ic_container.get(&PHENOTYPIC_ABNORMALITY);
assert!(pa_ic.is_some());
if let Some(pa_ic) = pa_ic {
assert_eq!(pa_ic.present, 0.);
assert_eq!(pa_ic.excluded, f64::INFINITY);
}
let myopia: TermId = "HP:0000545".parse().unwrap();
let myopia_ic = ic_container.get(&myopia);
assert!(myopia_ic.is_some());
if let Some(myopia_ic) = myopia_ic {
assert_eq!(myopia_ic.present, 3.0588936890535687);
assert_eq!(myopia_ic.excluded, 1.3219280948873624);
}
let ectopia_lentis: TermId = "HP:0001083".parse().unwrap();
let el_ic = ic_container.get(&ectopia_lentis);
assert!(el_ic.is_some());
if let Some(el_ic) = el_ic {
assert_eq!(el_ic.present, 2.3219280948873622);
assert_eq!(el_ic.excluded, f64::INFINITY);
}
}
#[test]
fn test_compute_ic_mica() {
let hpo = Arc::clone(HPO.get_or_init(load_hpo));
let fbn1 = fbn1_ectopia_lentis_subjects();
let cic = CohortIcCalculator::new(hpo, PHENOTYPIC_ABNORMALITY.clone());
let ic_mica = cic.compute_ic_mica(&fbn1);
let ectopia_lentis: TermId = "HP:0001083".parse().unwrap();
assert_eq!(
ic_mica.get(&TermPair::from(ectopia_lentis.clone())),
Some(&2.3219280948873622),
);
let myopia: TermId = "HP:0000545".parse().unwrap();
assert_eq!(
ic_mica.get(&TermPair::from(myopia.clone())),
Some(&3.0588936890535687),
);
let abn_eye_physiology: TermId = "HP:0012373".parse().unwrap();
assert_eq!(
ic_mica.get(&TermPair::from(abn_eye_physiology.clone())),
Some(&2.321928094887362),
);
let abn_of_the_eye: TermId = "HP:0000478".parse().unwrap();
assert_eq!(
ic_mica.get(&TermPair::from(abn_of_the_eye.clone())),
Some(&1.3219280948873624),
);
assert_eq!(
ic_mica.get(&TermPair::from((&myopia, &ectopia_lentis))),
Some(&1.3219280948873624),
);
let striae_distensae: TermId = "HP:0001065".parse().unwrap();
assert_eq!(
ic_mica.get(&TermPair::from(striae_distensae.clone())),
Some(&2.643856189774725),
);
assert_eq!(
ic_mica.get(&TermPair::from(PHENOTYPIC_ABNORMALITY.clone())),
None,
);
assert_eq!(
ic_mica.get(&TermPair::from((
striae_distensae.clone(),
ectopia_lentis.clone()
))),
None,
);
}
}