//! rsomics-tree-tipdist 0.1.0
//!
//! Patristic tip-to-tip distance matrix from a phylogenetic tree (the sum of
//! branch lengths between every pair of tips). Equivalent to scikit-bio's
//! `TreeNode.cophenet`, with byte-exact TSV output.
use std::collections::HashSet;
use std::io::{self, Write};

use rsomics_phylo_tree::{NodeId, Tree};

mod pyfloat;

/// Failure modes of [`tip_tip_distances`].
#[derive(Debug, thiserror::Error)]
pub enum TipDistError {
    /// No named tip was found, so there is nothing to put in the matrix.
    #[error("tree has no named tips")]
    NoTips,
    /// Two tips carry the same name; the payload is the repeated name.
    #[error("tree contains duplicate tip name {0:?}")]
    DuplicateTip(String),
}

/// Symmetric patristic distance matrix with tip labels in skbio's postorder
/// tip order.
///
/// `ids[k]` labels both row `k` and column `k`. `data` is row-major `n×n`, so
/// the distance between tips `i` and `j` lives at `data[i * n + j]`.
#[derive(Debug, Clone, PartialEq)]
pub struct TipDistMatrix {
    /// Tip names, one per row/column.
    pub ids: Vec<String>,
    /// Flattened row-major distances.
    pub data: Vec<f64>,
    /// Side length of the square matrix.
    n: usize,
}

impl TipDistMatrix {
    /// Number of tips, i.e. the side length of the square matrix.
    #[must_use]
    pub fn len(&self) -> usize {
        self.n
    }

    /// `true` when the matrix holds no tips at all.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.n == 0
    }

    /// Distance stored at row `i`, column `j`.
    ///
    /// # Panics
    ///
    /// Panics if `i` or `j` is not smaller than [`len`](Self::len).
    #[must_use]
    pub fn get(&self, i: usize, j: usize) -> f64 {
        // Without this check a too-large `j` with a small `i` still lands
        // inside `data` and silently returns a cell from a different row.
        assert!(
            i < self.n && j < self.n,
            "tip index ({}, {}) out of bounds for a {}x{} matrix",
            i,
            j,
            self.n,
            self.n
        );
        self.data[i * self.n + j]
    }

    /// Write the matrix in skbio `DistanceMatrix.write` (lsmat) TSV form:
    /// a header row starting with a tab then the ids, then one labelled row per
    /// tip. Each value is rendered by `pyfloat::push_repr`, which is meant to
    /// reproduce CPython's float `repr` so the output matches skbio byte for
    /// byte.
    ///
    /// # Errors
    ///
    /// Returns any I/O error reported by `w`.
    pub fn write_lsmat<W: Write>(&self, w: &mut W) -> io::Result<()> {
        // One reusable buffer: each line is assembled in memory and handed to
        // the writer in a single call.
        let mut line = String::new();
        for id in &self.ids {
            line.push('\t');
            line.push_str(id);
        }
        line.push('\n');
        w.write_all(line.as_bytes())?;

        for i in 0..self.n {
            line.clear();
            line.push_str(&self.ids[i]);
            let row = &self.data[i * self.n..(i + 1) * self.n];
            for v in row {
                line.push('\t');
                pyfloat::push_repr(&mut line, *v);
            }
            line.push('\n');
            w.write_all(line.as_bytes())?;
        }
        Ok(())
    }
}

/// Patristic distance between every pair of named tips. With `use_length =
/// false`, counts branches instead of summing lengths; missing branch lengths
/// count as 0. Tips without a name are left out of the matrix.
///
/// Mirrors skbio TreeNode.cophenet: one postorder sweep grows each tip's depth
/// edge by edge, and each tip-pair distance is the sum of their two depths
/// captured at their LCA. The order of float additions is kept the same as in
/// that sweep so results are intended to match skbio bit-for-bit.
///
/// # Errors
///
/// * [`TipDistError::NoTips`] if the tree has no named tip.
/// * [`TipDistError::DuplicateTip`] if two tips share a name.
pub fn tip_tip_distances(tree: &Tree, use_length: bool) -> Result<TipDistMatrix, TipDistError> {
    let mut ids = Vec::new();
    // `usize::MAX` marks every node that is not a named tip.
    let mut tip_index = vec![usize::MAX; tree.nodes.len()];
    collect_tips(tree, tree.root, &mut ids, &mut tip_index);

    let n = ids.len();
    if n == 0 {
        return Err(TipDistError::NoTips);
    }
    let mut seen = HashSet::with_capacity(n);
    for name in &ids {
        if !seen.insert(name.as_str()) {
            return Err(TipDistError::DuplicateTip(name.clone()));
        }
    }

    let mut data = vec![0.0_f64; n * n];
    // Running distance from each tip up to the node currently being processed.
    let mut depths = vec![0.0_f64; n];
    // Half-open span of tip indices below each node; `None` when the subtree
    // holds no named tip.
    let mut range: Vec<Option<(usize, usize)>> = vec![None; tree.nodes.len()];

    for &id in &postorder(tree) {
        let node = &tree.nodes[id];
        if node.children.is_empty() {
            let t = tip_index[id];
            // Unnamed tips keep the sentinel index; computing `t + 1` for them
            // would overflow, so they simply get no range.
            if t != usize::MAX {
                range[id] = Some((t, t + 1));
            }
            continue;
        }

        let mut clades: Vec<(usize, usize)> = Vec::with_capacity(node.children.len());
        for &child in &node.children {
            let (s, e) = match range[child] {
                Some(span) => span,
                None => continue,
            };
            let inc = if use_length {
                tree.nodes[child].branch_length.unwrap_or(0.0)
            } else {
                1.0
            };
            // Lift every tip under this child across the child's own edge.
            for d in &mut depths[s..e] {
                *d += inc;
            }
            clades.push((s, e));
        }

        // Tips from two different child clades meet for the first time here,
        // so this node is their LCA and their depths add up to the distance.
        for a in 0..clades.len() {
            let (s1, e1) = clades[a];
            for &(s2, e2) in &clades[a + 1..] {
                for i in s1..e1 {
                    for j in s2..e2 {
                        let v = depths[i] + depths[j];
                        data[i * n + j] = v;
                        data[j * n + i] = v;
                    }
                }
            }
        }

        if let (Some(&(first, _)), Some(&(_, last))) = (clades.first(), clades.last()) {
            range[id] = Some((first, last));
        }
    }

    Ok(TipDistMatrix { ids, data, n })
}

/// Gather the named tips below `id`, left to right.
///
/// Each named tip's name is appended to `ids`, and its position in `ids` is
/// written to `tip_index` at the tip's node id. Unnamed tips and internal
/// nodes leave their `tip_index` slot untouched.
///
/// The walk uses an explicit stack rather than recursion so that very deep
/// trees cannot exhaust the call stack.
fn collect_tips(tree: &Tree, id: NodeId, ids: &mut Vec<String>, tip_index: &mut [usize]) {
    let mut pending = vec![id];
    while let Some(current) = pending.pop() {
        let node = &tree.nodes[current];
        if node.children.is_empty() {
            if let Some(name) = &node.name {
                tip_index[current] = ids.len();
                ids.push(name.clone());
            }
            continue;
        }
        // Push children reversed so the leftmost child is popped first.
        pending.extend(node.children.iter().rev().copied());
    }
}

/// Node ids of the whole tree in postorder: every child, left to right, comes
/// before its parent, and the root comes last.
fn postorder(tree: &Tree) -> Vec<NodeId> {
    let mut sequence = Vec::with_capacity(tree.nodes.len());
    let mut pending = vec![tree.root];
    // Emit parent-first with the rightmost child handled first; flipping that
    // sequence afterwards yields left-to-right postorder.
    while let Some(current) = pending.pop() {
        sequence.push(current);
        pending.extend(tree.nodes[current].children.iter().copied());
    }
    sequence.reverse();
    sequence
}

#[cfg(test)]
mod tests {
    use super::*;

    // Compute the matrix for `tree` and return its lsmat TSV rendering as text.
    fn render(tree: &Tree, use_length: bool) -> String {
        let dm = tip_tip_distances(tree, use_length).unwrap();
        let mut buf = Vec::new();
        dm.write_lsmat(&mut buf).unwrap();
        String::from_utf8(buf).unwrap()
    }

    // End-to-end check of the TSV output for a four-tip tree, e.g.
    // a–c = 1 + 0.5 + 0.6 + 3 = 5.1.
    #[test]
    fn spec_example_matches_oracle() {
        let tree = Tree::from_newick("((a:1,b:2):0.5,(c:3,d:4):0.6);").unwrap();
        let got = render(&tree, true);
        let want = "\ta\tb\tc\td\n\
                    a\t0.0\t3.0\t5.1\t6.1\n\
                    b\t3.0\t0.0\t6.1\t7.1\n\
                    c\t5.1\t6.1\t0.0\t7.0\n\
                    d\t6.1\t7.1\t7.0\t0.0\n";
        assert_eq!(got, want);
    }

    // Named internal nodes (c, f, root) must not appear among the ids; only
    // the four tips do.
    #[test]
    fn doc_example_with_named_internals() {
        let tree = Tree::from_newick("((a:1,b:2)c:3,(d:4,e:5)f:6)root;").unwrap();
        let dm = tip_tip_distances(&tree, true).unwrap();
        assert_eq!(dm.ids, ["a", "b", "d", "e"]);
        assert_eq!(dm.get(0, 1), 3.0);
        assert_eq!(dm.get(0, 2), 14.0);
        assert_eq!(dm.get(0, 3), 15.0);
        assert_eq!(dm.get(2, 3), 9.0);
    }

    // With `use_length = false` every edge counts as 1 regardless of its
    // length.
    #[test]
    fn branch_counts_when_use_length_false() {
        let tree = Tree::from_newick("((a:1,b:2)c:3,(d:4,e:5)f:6)root;").unwrap();
        let dm = tip_tip_distances(&tree, false).unwrap();
        assert_eq!(dm.get(0, 1), 2.0);
        assert_eq!(dm.get(0, 2), 4.0);
        assert_eq!(dm.get(2, 3), 2.0);
    }

    // A tree without any branch lengths yields an all-zero matrix.
    #[test]
    fn missing_lengths_count_as_zero() {
        let tree = Tree::from_newick("((a,b),(c,d));").unwrap();
        let dm = tip_tip_distances(&tree, true).unwrap();
        for v in &dm.data {
            assert_eq!(*v, 0.0);
        }
    }

    // Two tips named "a" must be reported as an error rather than producing a
    // matrix with ambiguous labels.
    #[test]
    fn duplicate_tip_names_rejected() {
        let tree = Tree::from_newick("((a:1,a:2):0.5,(c:3,d:4):0.6);").unwrap();
        assert!(matches!(
            tip_tip_distances(&tree, true),
            Err(TipDistError::DuplicateTip(_))
        ));
    }
}