speedytree 0.1.0

Canonical and RapidNJ implementations of Neighbor-joining in Rust
Documentation
use std::collections::{HashMap, HashSet};

use bit_set::BitSet;
use bit_vec::BitVec;
use petgraph::stable_graph::EdgeIndex;

use crate::Tree;

/// Count the nodes of `x` that have exactly one neighbor.
///
/// Such degree-one nodes are what the distance functions in this module
/// treat as the leaves of the tree when checking that two trees are comparable.
fn count_leaves(x: &Tree) -> usize {
    x.node_indices()
        .filter(|&node| x.neighbors(node).count() == 1)
        .count()
}

/// Calculate the [Branch-Score distance](https://www.cs.mcgill.ca/~birch/birchhomedir/doc/Phylip/treedist.html) between two trees. It takes the branch length into account.
///
/// Every edge of each tree is mapped to the leaf bipartition it induces
/// (see `collect_bit_vector`) together with its weight. The result is the sum,
/// over every bipartition present in *either* tree, of the squared difference
/// between the two weights; a bipartition missing from one tree contributes
/// with weight `0.0` on that side. No square root is taken.
///
/// If several edges of the same tree induce the same bipartition, only the
/// weight of the last such edge (in edge-index order) is kept.
///
/// # Panics
///
/// Panics if the two trees do not have the same number of leaves
/// (nodes with exactly one neighbor).
pub fn branch_score(a: &Tree, b: &Tree) -> f64 {
    let n_leaves = (count_leaves(a), count_leaves(b));
    assert_eq!(n_leaves.0, n_leaves.1);
    let n_leaves = n_leaves.0;

    // Bipartition -> branch length, one map per tree.
    let bits_a: HashMap<BitVec, f64> = a
        .edge_indices()
        .zip(a.edge_weights())
        .map(|(edge, w)| (collect_bit_vector(a, edge, n_leaves), *w))
        .collect();
    let bits_b: HashMap<BitVec, f64> = b
        .edge_indices()
        .zip(b.edge_weights())
        .map(|(edge, w)| (collect_bit_vector(b, edge, n_leaves), *w))
        .collect();

    // Union of the keys of both maps. `chain` (rather than `zip`) is required:
    // `zip` stops at the shorter iterator and would silently drop bipartitions
    // whenever the two trees have a different number of distinct splits.
    let keys: HashSet<&BitVec> = bits_a.keys().chain(bits_b.keys()).collect();

    keys.into_iter().fold(0.0, |distance, key| {
        let weight_a = bits_a.get(key).copied().unwrap_or(0.0);
        let weight_b = bits_b.get(key).copied().unwrap_or(0.0);
        distance + (weight_a - weight_b).powi(2)
    })
}

/// Calculate the [Robinson-Foulds](https://en.wikipedia.org/wiki/Robinson%E2%80%93Foulds_metric) distance between two trees. It doesn't take branch length into account.
///
/// Each edge is reduced to the leaf bipartition it induces; the distance is
/// the number of bipartitions that occur in exactly one of the two trees.
///
/// # Panics
///
/// Panics if the two trees do not have the same number of leaves
/// (nodes with exactly one neighbor).
pub fn robinson_foulds(a: &Tree, b: &Tree) -> usize {
    let leaves_a = count_leaves(a);
    let leaves_b = count_leaves(b);
    assert_eq!(leaves_a, leaves_b);

    // Set of distinct bipartitions induced by the edges of one tree.
    let splits_of = |tree: &Tree| -> HashSet<BitVec> {
        tree.edge_indices()
            .map(|edge| collect_bit_vector(tree, edge, leaves_a))
            .collect()
    };
    let splits_a = splits_of(a);
    let splits_b = splits_of(b);

    // Splits only in `a` plus splits only in `b`.
    splits_a.symmetric_difference(&splits_b).count()
}

/// Encode the leaf bipartition ("split") obtained by removing `edge` from `tree`.
///
/// A node counts as a labelled leaf here when its weight (its name) is a
/// non-empty string. Labelled leaves are numbered `0, 1, 2, ...` in node-index
/// order, and that number is the leaf's bit position. Names themselves are not
/// compared, so two trees yield comparable vectors only when their labelled
/// nodes appear in the same relative order.
///
/// The returned vector always has `max(n_leaves, number of labelled nodes)`
/// bits and is canonical: it marks the leaves on the side of `edge` that does
/// *not* contain leaf number 0, so bit 0 is never set and the result does not
/// depend on the direction in which the edge was stored.
///
/// # Panics
///
/// Panics if `edge` is not an edge of `tree`.
fn collect_bit_vector(
    tree: &petgraph::Graph<String, f64, petgraph::Undirected>,
    edge: EdgeIndex,
    n_leaves: usize,
) -> BitVec {
    // Dense rank of every labelled node. Using the rank (instead of the raw
    // position among *all* nodes) keeps every bit index below the vector
    // width, so the width never depends on which leaves end up on a side.
    let leaves: HashMap<_, usize> = tree
        .node_indices()
        .filter(|&node| !tree[node].is_empty())
        .enumerate()
        .map(|(rank, node)| (node, rank))
        .collect();

    let (parent_left, parent_right) = tree.edge_endpoints(edge).expect("Valid edge");

    // Depth-first walk from `parent_left` that never steps across `edge`:
    // collects every node on the "left" side of the removed edge.
    let mut left_nodes = HashSet::new();
    left_nodes.insert(parent_left);
    let mut left_queue = vec![parent_left];
    while let Some(node) = left_queue.pop() {
        for neighbor in tree.neighbors(node) {
            if node == parent_left && neighbor == parent_right {
                continue;
            }
            if left_nodes.insert(neighbor) {
                left_queue.push(neighbor);
            }
        }
    }

    // Canonical orientation: if leaf 0 sits on the left, describe the split
    // by the right side instead. Picking the side explicitly (rather than
    // negating the whole vector) leaves any padding bits untouched.
    let first_leaf_on_left = leaves
        .iter()
        .any(|(node, &rank)| rank == 0 && left_nodes.contains(node));

    let mut bit_vect = BitSet::with_capacity(n_leaves.max(leaves.len()));
    for (node, &rank) in &leaves {
        if left_nodes.contains(node) != first_leaf_on_left {
            bit_vect.insert(rank);
        }
    }
    bit_vect.into_bit_vec()
}

#[cfg(test)]
mod tests {
    use rand::distributions::{Alphanumeric, DistString};

    use super::*;

    // Positions of the nodes created by `quartet`: four leaves, then the two
    // unnamed internal nodes.
    const A: usize = 0;
    const B: usize = 1;
    const C: usize = 2;
    const D: usize = 3;
    const U: usize = 4;
    const V: usize = 5;

    /// Build a tree whose first nodes are leaves named after `names`, followed
    /// by two unnamed internal nodes, connected by `edges` given as
    /// `(from, to, weight)` triples over the node positions above.
    fn quartet(names: &[String], edges: &[(usize, usize, f64)]) -> Tree {
        let mut tree = Tree::new_undirected();
        let mut nodes: Vec<_> = names
            .iter()
            .map(|name| tree.add_node(name.to_owned()))
            .collect();
        for _ in 0..2 {
            nodes.push(tree.add_node("".to_owned()));
        }
        for &(from, to, weight) in edges {
            tree.add_edge(nodes[from], nodes[to], weight);
        }
        tree
    }

    #[test]
    fn test_distance() {
        // Simulate random leaf node names
        let names: Vec<String> = (0..4)
            .map(|_| Alphanumeric.sample_string(&mut rand::thread_rng(), 16))
            .collect();

        // ((a, b), (c, d)) with distinct branch lengths.
        let t1 = quartet(
            &names,
            &[
                (A, V, 2.0),
                (B, V, 1.0),
                (C, U, 3.0),
                (D, U, 10.0),
                (V, U, 5.0),
            ],
        );
        // ((a, c), (b, d)) with unit branch lengths.
        let t2 = quartet(
            &names,
            &[
                (A, V, 1.0),
                (C, V, 1.0),
                (B, U, 1.0),
                (D, U, 1.0),
                (V, U, 1.0),
            ],
        );

        // The internal split differs between the trees: one unique split each.
        assert_eq!(robinson_foulds(&t1, &t2), 2);
        assert_eq!(robinson_foulds(&t1, &t1), 0);
        assert_eq!(robinson_foulds(&t2, &t2), 0);

        // Leaf branches: 1 + 0 + 4 + 81; internal branches: 25 + 1.
        assert_eq!(branch_score(&t1, &t2), 112.0);
        assert_eq!(branch_score(&t1, &t1), 0.0);
        assert_eq!(branch_score(&t2, &t2), 0.0);
    }
}