use itertools::Itertools;
use num::{Zero, Float, NumCast, One};
use std::fmt::Debug;
use vers_vecs::BitVec;
#[cfg(feature = "non_crypto_hash")]
use fxhash::{FxHashMap as HashMap, FxHashSet as HashSet};
#[cfg(not(feature = "non_crypto_hash"))]
use std::collections::{HashMap, HashSet};
#[cfg(feature = "parallel")]
use rayon::prelude::*;
use crate::prelude::*;
/// Shorthand for the `Zeta` associated type of the node type used by the rooted tree `T`.
pub type TreeNodeZeta<T> = <<T as RootedTree>::Node as RootedZetaNode>::Zeta;
/// Accessors for the per-node zeta value of a rooted tree whose nodes implement
/// `RootedZetaNode`.
pub trait PathFunction: RootedTree
where
    <Self as RootedTree>::Node: RootedZetaNode,
{
    /// Evaluates `zeta_func` on every node of the tree and stores the result on that node.
    ///
    /// Returns `None` as soon as a node id cannot be resolved mutably; nodes processed
    /// before that point keep their newly assigned value.
    fn set_zeta(
        &mut self,
        zeta_func: fn(&Self, TreeNodeID<Self>) -> TreeNodeZeta<Self>,
    ) -> Option<()> {
        // Snapshot the ids first so the tree is not borrowed while nodes are mutated.
        let node_ids = self.get_node_ids().collect_vec();
        for node_id in node_ids {
            let zeta = zeta_func(self, node_id);
            self.set_node_zeta(node_id, Some(zeta))?;
        }
        Some(())
    }

    /// Returns the zeta value stored on `node_id`, or `None` when the node does not exist
    /// or the node reports no value.
    fn get_zeta(&self, node_id: TreeNodeID<Self>) -> Option<TreeNodeZeta<Self>> {
        self.get_node(node_id)?.get_zeta()
    }

    /// Returns `true` when `node_id` exists and its node reports a zeta value as set.
    ///
    /// An unknown `node_id` yields `false` instead of panicking, in line with the
    /// `Option`-based handling of missing nodes in `get_zeta` and `set_node_zeta`.
    fn is_zeta_set(&self, node_id: TreeNodeID<Self>) -> bool {
        self.get_node(node_id)
            .map_or(false, |node| node.is_zeta_set())
    }

    /// Returns `true` when every node of the tree reports its zeta value as set
    /// (vacuously `true` when the tree yields no nodes).
    fn is_all_zeta_set(&self) -> bool {
        self.get_nodes().all(|node| node.is_zeta_set())
    }

    /// Overwrites the zeta value of `node_id` with `zeta` (`None` is forwarded unchanged
    /// to the node's `set_zeta`).
    ///
    /// Returns `None` when `node_id` cannot be resolved, `Some(())` otherwise.
    fn set_node_zeta(
        &mut self,
        node_id: TreeNodeID<Self>,
        zeta: Option<TreeNodeZeta<Self>>,
    ) -> Option<()> {
        self.get_node_mut(node_id)?.set_zeta(zeta);
        Some(())
    }
}
/// Distance queries on a rooted weighted tree.
///
/// Both methods are required; this trait provides no default implementations.
pub trait DistanceMatrix: RootedWeightedTree
where
<Self as RootedTree>::Node: RootedWeightedNode,
{
/// Returns a two-dimensional table of node weights.
///
/// NOTE(review): presumably entry `[i][j]` is the distance between the i-th and
/// j-th node (or leaf); the indexing convention is chosen by the implementor — confirm.
fn matrix(&self) -> Vec<Vec<TreeNodeWeight<Self>>>;
/// Returns a weight-typed value for the pair `node_id_1`, `node_id_2`.
///
/// NOTE(review): presumably the path distance between the two nodes — confirm
/// against the implementation.
fn pairwise_distance(
&self,
node_id_1: TreeNodeID<Self>,
node_id_2: TreeNodeID<Self>,
) -> TreeNodeWeight<Self>;
}
/// Robinson-Foulds style comparison of two rooted trees based on the taxon
/// bipartitions induced by their nodes.
pub trait RobinsonFoulds
where
    Self: RootedTree + RootedMetaTree + Clusters,
    <Self as RootedTree>::Node: RootedMetaNode,
{
    /// Counts the bipartitions that occur in exactly one of `self` and `tree`.
    ///
    /// Every leaf and every non-root internal node contributes the bit vector of the
    /// taxa below it together with its complement, both taken over the union of the
    /// two trees' taxa. Because each bipartition is therefore stored in both
    /// orientations, the size of the symmetric difference of the two collections is
    /// halved before it is returned.
    ///
    /// # Panics
    ///
    /// Panics if a leaf has no taxon attached.
    fn rf(&self, tree: &Self) -> usize {
        // Index every taxon of either tree so both trees share one bit layout.
        let mut all_taxa: HashSet<&TreeNodeMeta<Self>> = self.get_taxa_space().collect();
        all_taxa.extend(tree.get_taxa_space());
        let num_taxa = all_taxa.len();
        let taxa_index: HashMap<&TreeNodeMeta<Self>, usize> = all_taxa
            .into_iter()
            .enumerate()
            .map(|(idx, taxa)| (taxa, idx))
            .collect();

        // Collects, for one tree, every bipartition in both orientations.
        let bipartitions = |t: &Self| -> HashSet<BitVec> {
            let root_id = t.get_root_id();
            // Taxa below each visited node; children are visited before their parent.
            let mut below: HashMap<TreeNodeID<Self>, BitVec> = HashMap::default();
            let mut splits: HashSet<BitVec> = HashSet::default();
            for n_id in t.postord_ids(root_id) {
                let mut bp = BitVec::from_zeros(num_taxa);
                if t.is_leaf(n_id) {
                    let leaf_meta = t.get_node_taxa(n_id).unwrap();
                    bp.flip_bit(*taxa_index.get(leaf_meta).unwrap());
                } else if n_id == root_id {
                    // The root separates nothing, so it contributes no bipartition.
                    continue;
                } else {
                    for child_id in t.get_node_children_ids(n_id) {
                        let _ = bp.apply_mask_or(below.get(&child_id).unwrap());
                    }
                }
                // Complement of `bp`: all ones XOR `bp`.
                let mut bp_rev = BitVec::from_ones(num_taxa);
                let _ = bp_rev.apply_mask_xor(&bp);
                below.insert(n_id, bp.clone());
                splits.insert(bp);
                splits.insert(bp_rev);
            }
            splits
        };

        let self_splits = bipartitions(self);
        let tree_splits = bipartitions(tree);
        self_splits.symmetric_difference(&tree_splits).count() / 2
    }
}
pub trait ClusterMatching
where
Self: RootedTree + RootedMetaTree + Clusters,
<Self as RootedTree>::Node: RootedMetaNode,
{
fn cm(&self, tree: &Self) -> usize {
let self_clusters = self
.get_node_ids()
.map(|node_id| {
self.get_cluster_ids(node_id)
.map(|id| self.get_node_taxa(id).unwrap())
.collect_vec()
})
.collect::<HashSet<_>>();
let tree_clusters = tree
.get_node_ids()
.map(|node_id| {
tree.get_cluster_ids(node_id)
.map(|id| tree.get_node_taxa(id).unwrap())
.collect_vec()
})
.collect::<HashSet<_>>();
self_clusters.difference(&tree_clusters).collect_vec().len()
+ tree_clusters.difference(&self_clusters).collect_vec().len()
}
}
/// Comparison of two rooted trees that matches every cluster of one tree to its
/// closest cluster in the other.
pub trait ClusterAffinity
where
    Self: RootedTree + RootedMetaTree + Clusters,
    <Self as RootedTree>::Node: RootedMetaNode,
{
    /// For every node `v` of `self`, finds the node `c` of `tree` minimising
    /// `|cluster(v)| + |cluster(c)| - 2 * |cluster(v) ∩ cluster(c)|` and returns the sum
    /// of these minima over all `v`.
    ///
    /// Sizes and intersections are filled in bottom-up: both trees are walked in
    /// post-order, so the entries of a node's children exist before the node is handled.
    ///
    /// # Panics
    ///
    /// Panics if a leaf of either tree has no taxon attached.
    fn ca(&self, tree: &Self) -> usize {
        let mut self_sizes: HashMap<TreeNodeID<Self>, usize> = HashMap::default();
        let mut tree_sizes: HashMap<TreeNodeID<Self>, usize> = HashMap::default();
        let mut overlaps: HashMap<(TreeNodeID<Self>, TreeNodeID<Self>), usize> =
            HashMap::default();
        let mut total = 0;

        for v in self.postord_ids(self.get_root_id()) {
            let v_is_leaf = self.is_leaf(v);
            let v_size: usize = if v_is_leaf {
                1
            } else {
                self.get_node_children_ids(v)
                    .map(|child| *self_sizes.get(&child).unwrap())
                    .sum()
            };
            self_sizes.insert(v, v_size);

            let mut best = usize::MAX;
            for c in tree.postord_ids(tree.get_root_id()) {
                let (c_size, shared): (usize, usize) = if tree.is_leaf(c) {
                    let shared = if v_is_leaf {
                        // Two leaves overlap exactly when they carry the same taxon.
                        usize::from(
                            self.get_node_taxa(v).unwrap() == tree.get_node_taxa(c).unwrap(),
                        )
                    } else {
                        // A leaf of `tree` lies in `v`'s cluster when it lies in the
                        // cluster of one of `v`'s children.
                        usize::from(
                            self.get_node_children_ids(v)
                                .any(|ch| *overlaps.get(&(ch, c)).unwrap_or(&0) > 0),
                        )
                    };
                    (1, shared)
                } else {
                    // Internal node of `tree`: accumulate over its children.
                    tree.get_node_children_ids(c)
                        .fold((0, 0), |(size_acc, shared_acc), cch| {
                            (
                                size_acc + *tree_sizes.get(&cch).unwrap(),
                                shared_acc + *overlaps.get(&(v, cch)).unwrap(),
                            )
                        })
                };
                tree_sizes.insert(c, c_size);
                overlaps.insert((v, c), shared);
                best = best.min(c_size + v_size - (2 * shared));
            }
            total += best;
        }
        total
    }
}
/// Weight-valued comparison of two rooted, weighted trees that carry taxa.
///
/// The single method is required; this trait provides no default implementation.
pub trait WeightedRobinsonFoulds
where
Self: RootedWeightedTree + RootedMetaTree + Clusters,
<Self as RootedTree>::Node: RootedWeightedNode + RootedMetaNode,
{
/// Returns a weight-typed value computed from `self` and `tree`.
///
/// NOTE(review): presumably the weighted Robinson-Foulds distance (going by the
/// trait name) — confirm against the implementation.
fn wrfs(&self, tree: &Self) -> TreeNodeWeight<Self>;
}
/// Cophenetic-style distance between two rooted trees whose nodes carry zeta values.
///
/// For every unordered pair of taxa (including a taxon paired with itself) the
/// absolute difference of a zeta value in the two trees is collected, and the
/// resulting vector is reduced with a norm selected by `norm`.
pub trait CopheneticDistance:
    PathFunction + RootedMetaTree + Clusters + Ancestors + ContractTree + Debug
where
    <Self as RootedTree>::Node: RootedMetaNode + RootedZetaNode,
    TreeNodeZeta<Self>: NodeWeight,
{
    /// Returns the zeta value of the node associated with `taxa`.
    ///
    /// # Panics
    ///
    /// Panics if `taxa` has no node in this tree or that node has no zeta value.
    fn get_zeta_taxa(&self, taxa: &TreeNodeMeta<Self>) -> TreeNodeZeta<Self> {
        self.get_zeta(self.get_taxa_node_id(taxa).unwrap()).unwrap()
    }

    /// Reduces `vector` to a single value.
    ///
    /// * `norm == 0`: the maximum entry, folded starting from zero.
    /// * `norm == 1`: the plain sum of the entries.
    /// * otherwise: the sum of the entries raised to `norm`, raised to `1 / norm`.
    ///
    /// # Panics
    ///
    /// Panics if `norm` cannot be converted to the zeta type.
    fn compute_norm(
        vector: impl Iterator<Item = TreeNodeZeta<Self>>,
        norm: u32,
    ) -> TreeNodeZeta<Self> {
        if norm == 0 {
            return vector.fold(<TreeNodeZeta<Self>>::zero(), |acc, x| acc.max(x));
        }
        if norm == 1 {
            return vector.sum();
        }
        vector
            .map(|x| {
                // Integer power by repeated multiplication.
                let mut out = <TreeNodeZeta<Self>>::one();
                for _ in 0..norm {
                    out = out * x;
                }
                out
            })
            .sum::<TreeNodeZeta<Self>>()
            .powf(
                <TreeNodeZeta<Self> as NumCast>::from(norm)
                    .unwrap()
                    .powi(-1),
            )
    }

    /// Variant of `compute_norm` that evaluates the general case with a parallel iterator.
    ///
    /// `norm == 0` yields the maximum entry (folded from zero), matching `compute_norm`;
    /// without this branch the general formula would compute `count.powf(1 / 0)`.
    ///
    /// # Panics
    ///
    /// Panics if `norm` cannot be converted to the zeta type.
    #[cfg(feature = "parallel")]
    fn compute_norm_par(
        vector: Vec<TreeNodeZeta<Self>>,
        norm: u32,
    ) -> TreeNodeZeta<Self> {
        if norm == 0 {
            return vector
                .into_iter()
                .fold(<TreeNodeZeta<Self>>::zero(), |acc, x| acc.max(x));
        }
        if norm == 1 {
            return vector.into_iter().sum();
        }
        vector
            .par_iter()
            .map(|x| x.clone().powi(norm as i32))
            .sum::<TreeNodeZeta<Self>>()
            .powf(
                <TreeNodeZeta<Self> as NumCast>::from(norm)
                    .unwrap()
                    .powi(-1),
            )
    }

    /// Distance between `self` and `tree` restricted to the taxa present in both trees.
    ///
    /// # Panics
    ///
    /// Panics with "Zeta values not set" unless every node of both trees has its zeta
    /// value set.
    fn cophen_dist<'a>(
        &'a self,
        tree: &'a Self,
        norm: u32,
    ) -> TreeNodeZeta<Self> {
        if !self.is_all_zeta_set() || !tree.is_all_zeta_set() {
            panic!("Zeta values not set");
        }
        let self_taxa = self
            .get_taxa_space()
            .collect::<HashSet<&TreeNodeMeta<Self>>>();
        let tree_taxa = tree
            .get_taxa_space()
            .collect::<HashSet<&TreeNodeMeta<Self>>>();
        let taxa_set = self_taxa.intersection(&tree_taxa).cloned();
        self.cophen_dist_by_taxa(tree, norm, taxa_set)
    }

    /// Parallel counterpart of `cophen_dist`.
    ///
    /// # Panics
    ///
    /// Panics with "Zeta values not set" unless every node of both trees has its zeta
    /// value set.
    #[cfg(feature = "parallel")]
    fn cophen_dist_par<'a>(
        &'a self,
        tree: &'a Self,
        norm: u32,
    ) -> TreeNodeZeta<Self> {
        if !self.is_all_zeta_set() || !tree.is_all_zeta_set() {
            panic!("Zeta values not set");
        }
        let self_taxa = self
            .get_taxa_space()
            .collect::<HashSet<&TreeNodeMeta<Self>>>();
        let tree_taxa = tree
            .get_taxa_space()
            .collect::<HashSet<&TreeNodeMeta<Self>>>();
        // Materialise the shared taxa so an owning iterator can be handed over.
        let taxa_set = self_taxa.intersection(&tree_taxa).cloned().collect_vec();
        self.cophen_dist_by_taxa_par(tree, norm, taxa_set.into_iter())
    }

    /// Parallel counterpart of `cophen_dist_by_taxa`: the taxon pairs are bridged onto
    /// a parallel iterator and the differences are reduced with `compute_norm_par`.
    ///
    /// # Panics
    ///
    /// Panics if a taxon is missing from either tree or a required zeta value is unset.
    #[cfg(feature = "parallel")]
    fn cophen_dist_by_taxa_par<'a>(
        &'a self,
        tree: &'a Self,
        norm: u32,
        taxa_set: impl Iterator<Item = &'a TreeNodeMeta<Self>> + Send,
    ) -> TreeNodeZeta<Self> {
        let cophen_vec: Vec<TreeNodeZeta<Self>> = taxa_set
            .combinations_with_replacement(2)
            .par_bridge()
            .map(|x| {
                if x[0] == x[1] {
                    // A taxon paired with itself: compare its own zeta values.
                    let zeta_1 = self.get_zeta_taxa(x[0]);
                    let zeta_2 = tree.get_zeta_taxa(x[0]);
                    (zeta_1 - zeta_2).abs()
                } else {
                    // Two distinct taxa: compare the zeta values of their LCAs.
                    let self_ids = x
                        .iter()
                        .map(|a| self.get_taxa_node_id(a).unwrap())
                        .collect_vec();
                    let tree_ids = x
                        .iter()
                        .map(|a| tree.get_taxa_node_id(a).unwrap())
                        .collect_vec();
                    let t_lca_id = self.get_lca_id(self_ids.as_slice());
                    let t_hat_lca_id = tree.get_lca_id(tree_ids.as_slice());
                    let zeta_1 = self.get_zeta(t_lca_id).unwrap();
                    let zeta_2 = tree.get_zeta(t_hat_lca_id).unwrap();
                    (zeta_1 - zeta_2).abs()
                }
            })
            .collect();
        Self::compute_norm_par(cophen_vec, norm)
    }

    /// Distance between `self` and `tree` over the taxa yielded by `taxa_set`.
    ///
    /// For a taxon paired with itself the zeta values of its own nodes are compared;
    /// for two distinct taxa the zeta values of their lowest common ancestors are
    /// compared. The absolute differences are reduced with `compute_norm`.
    ///
    /// # Panics
    ///
    /// Panics if a taxon is missing from either tree or a required zeta value is unset.
    fn cophen_dist_by_taxa<'a>(
        &'a self,
        tree: &'a Self,
        norm: u32,
        taxa_set: impl Iterator<Item = &'a TreeNodeMeta<Self>>,
    ) -> TreeNodeZeta<Self> {
        let cophen_vec = taxa_set.combinations_with_replacement(2).map(|x| {
            if x[0] == x[1] {
                // A taxon paired with itself: compare its own zeta values.
                let zeta_1 = self.get_zeta_taxa(x[0]);
                let zeta_2 = tree.get_zeta_taxa(x[0]);
                (zeta_1 - zeta_2).abs()
            } else {
                // Two distinct taxa: compare the zeta values of their LCAs.
                let self_ids = x
                    .iter()
                    .map(|a| self.get_taxa_node_id(a).unwrap())
                    .collect_vec();
                let tree_ids = x
                    .iter()
                    .map(|a| tree.get_taxa_node_id(a).unwrap())
                    .collect_vec();
                let t_lca_id = self.get_lca_id(self_ids.as_slice());
                let t_hat_lca_id = tree.get_lca_id(tree_ids.as_slice());
                let zeta_1 = self.get_zeta(t_lca_id).unwrap();
                let zeta_2 = tree.get_zeta(t_hat_lca_id).unwrap();
                (zeta_1 - zeta_2).abs()
            }
        });
        Self::compute_norm(cophen_vec, norm)
    }
}