mod config;
mod tree;
pub use crate::config::{FastaSequence, NJConfig, MSA};
use crate::tree::NJNode;
use bitvec::prelude::*;
use std::collections::HashMap;
/// Symmetric distance matrix with an implicit zero diagonal.
///
/// Only the strictly lower triangle is stored, packed row by row into `data`: entry `(i, j)`
/// with `i > j` lives at index `i * (i - 1) / 2 + j`, so the storage order is
/// `(1,0), (2,0), (2,1), (3,0), …`. Accessing `(j, i)` hits the same slot as `(i, j)`.
#[derive(Clone, Debug)]
pub struct TriMat {
    /// Packed lower-triangle entries; `with_names` allocates `n * (n - 1) / 2` of them.
    pub data: Vec<f64>,
    /// Label of each row/column; its length is the matrix dimension.
    pub node_names: Vec<String>,
}

impl TriMat {
    /// Creates an all-zero matrix whose rows/columns are labelled by `names`.
    ///
    /// An empty `names` produces an empty 0×0 matrix.
    pub fn with_names(names: Vec<String>) -> Self {
        let n = names.len();
        Self {
            // `saturating_sub` keeps `n == 0` from underflowing (a panic in debug builds).
            data: vec![0.0; n * n.saturating_sub(1) / 2],
            node_names: names,
        }
    }

    /// Number of rows (equivalently columns) of the matrix.
    pub fn dim(&self) -> usize {
        self.node_names.len()
    }

    /// Position of the off-diagonal entry `(i, j)` in `data`.
    ///
    /// Callers must guarantee `i != j`; otherwise `hi - 1` would underflow when both are 0.
    fn idx(&self, i: usize, j: usize) -> usize {
        let (hi, lo) = if i > j { (i, j) } else { (j, i) };
        hi * (hi - 1) / 2 + lo
    }

    /// Returns the entry at `(i, j)`; the diagonal is always `0.0`.
    ///
    /// # Panics
    ///
    /// Panics if `i != j` and `max(i, j) >= self.dim()`.
    pub fn get(&self, i: usize, j: usize) -> f64 {
        if i == j {
            0.0
        } else {
            self.data[self.idx(i, j)]
        }
    }

    /// Stores `val` at `(i, j)` (and therefore at `(j, i)`); writes to the diagonal are ignored.
    ///
    /// # Panics
    ///
    /// Panics if `i != j` and `max(i, j) >= self.dim()`.
    pub fn set(&mut self, i: usize, j: usize, val: f64) {
        if i != j {
            let index = self.idx(i, j);
            self.data[index] = val;
        }
    }
}
/// Builds a tree from a distance matrix using the Saitou–Nei neighbor-joining algorithm.
///
/// Leaves are named after `dist.node_names`. Internal nodes are named `Node{k}`, with `k`
/// counting up from the number of leaves. Branch lengths produced by the iterative joins are
/// clamped at zero. The last two clusters are joined under the root, and the remaining
/// distance is split equally between them. A single-entry matrix yields a lone leaf.
///
/// # Errors
///
/// Returns `Err` for an empty matrix. It also returns `Err` if the join bookkeeping becomes
/// inconsistent (no joinable pair, a missing node, or not exactly two survivors).
fn neighbor_joining(mut dist: TriMat) -> Result<NJNode, String> {
    let n = dist.dim();
    if n == 0 {
        return Err("Empty distance matrix".to_string());
    }
    if n == 1 {
        return Ok(NJNode::leaf(dist.node_names[0].clone()));
    }
    // `active[i]` stays true while slot `i` still holds an unmerged cluster. A join keeps the
    // merged cluster in slot `i_min` and retires slot `j_min`.
    let mut active: BitVec<u8, Lsb0> = BitVec::repeat(true, n);
    let mut nodes: HashMap<usize, NJNode> = (0..n)
        .map(|i| (i, NJNode::leaf(dist.node_names[i].clone())))
        .collect();
    // `row_sums[i]` = sum of distances from cluster `i` to every other active cluster.
    let mut row_sums: Vec<f64> = (0..n)
        .map(|i| (0..n).map(|j| dist.get(i, j)).sum())
        .collect();
    let mut next_internal = n;
    // Each iteration merges two clusters; the last two are joined at the root below.
    for _ in 0..(n - 2) {
        // At least three clusters are active here, so `active_count - 2.0 >= 1.0`.
        let active_count = active.count_ones() as f64;
        let active_ref = &active;
        let row_sums_ref = &row_sums;
        let dist_ref = &dist;
        // Pick the active pair minimising Q(i, j) = (r - 2) * d(i, j) - R_i - R_j.
        let pair_opt = (0..n)
            .filter(|&i| active_ref[i])
            .flat_map(|i| {
                (0..i).filter(move |&j| active_ref[j]).map(move |j| {
                    let d = dist_ref.get(i, j);
                    (
                        i,
                        j,
                        (active_count - 2.0) * d - row_sums_ref[i] - row_sums_ref[j],
                        d,
                    )
                })
            })
            // `total_cmp` is a total order, so a NaN distance cannot panic here.
            // `min_by` returns the first of several equal minima.
            .min_by(|a, b| a.2.total_cmp(&b.2))
            .map(|(i, j, _, d)| (i, j, d));
        let (i_min, j_min, dij) = match pair_opt {
            Some(t) => t,
            None => return Err("Failed to find a pair to join: no active pair found".to_string()),
        };
        // Branch lengths from `i_min` and `j_min` to their new common parent.
        let li = (0.5 * dij + (row_sums[i_min] - row_sums[j_min]) / (2.0 * (active_count - 2.0)))
            .max(0.0);
        let lj = (dij - li).max(0.0);
        let internal_name = format!("Node{}", next_internal);
        next_internal += 1;
        let left_node = nodes
            .remove(&i_min)
            .ok_or_else(|| format!("Internal error: node {} missing during join", i_min))?;
        let right_node = nodes
            .remove(&j_min)
            .ok_or_else(|| format!("Internal error: node {} missing during join", j_min))?;
        let new_node = NJNode::internal(internal_name, left_node, right_node, li, lj);
        nodes.insert(i_min, new_node);
        active.set(j_min, false);
        // Replace row `i_min` with distances to the merged cluster, and keep every row sum
        // consistent with the shrunken active set.
        let mut merged_row_sum = 0.0;
        for k in (0..n).filter(|&k| active[k] && k != i_min) {
            let dik = dist.get(i_min, k);
            let djk = dist.get(j_min, k);
            let d_new = 0.5 * (dik + djk - dij);
            // Row k loses its distances to i and j and gains one to the merged cluster.
            row_sums[k] = row_sums[k] - dik - djk + d_new;
            merged_row_sum += d_new;
            dist.set(i_min, k, d_new);
        }
        // The merged row is exactly the sum of its new distances. Deriving it incrementally
        // from row i would also subtract d(j, k) terms that row i never contained.
        row_sums[i_min] = merged_row_sum;
        row_sums[j_min] = 0.0;
    }
    let remaining: Vec<usize> = (0..n).filter(|&i| active[i]).collect();
    if remaining.len() != 2 {
        return Err(format!(
            "Expected 2 remaining active nodes, found {}. Input may be malformed.",
            remaining.len()
        ));
    }
    let i = remaining[0];
    let j = remaining[1];
    let dij = dist.get(i, j);
    let root_name = format!("Node{}", next_internal);
    let left = nodes
        .remove(&i)
        .ok_or_else(|| format!("Internal error: remaining node {} not found", i))?;
    let right = nodes
        .remove(&j)
        .ok_or_else(|| format!("Internal error: remaining node {} not found", j))?;
    // Root the tree at the midpoint of the final edge.
    Ok(NJNode::internal(
        root_name,
        left,
        right,
        dij / 2.0,
        dij / 2.0,
    ))
}
/// Builds the pairwise p-distance matrix of an alignment.
///
/// For each pair of records, columns where either sequence holds a `'-'` gap are ignored.
/// The distance is the fraction of the remaining columns whose characters differ. A pair with
/// no comparable column gets `0.0`. Sequences are compared `char` by `char`, and only up to
/// the length of the shorter one. Matrix rows/columns are labelled with the record headers,
/// in input order.
fn dist_from_msa(msa: &[FastaSequence]) -> TriMat {
    let headers: Vec<String> = msa.iter().map(|record| record.header.clone()).collect();
    let mut matrix = TriMat::with_names(headers);
    for (row, upper) in msa.iter().enumerate() {
        for (col, lower) in msa[..row].iter().enumerate() {
            let mut mismatches = 0;
            let mut compared = 0;
            for (a, b) in upper.sequence.chars().zip(lower.sequence.chars()) {
                if a == '-' || b == '-' {
                    continue;
                }
                compared += 1;
                if a != b {
                    mismatches += 1;
                }
            }
            let distance = if compared == 0 {
                0.0
            } else {
                mismatches as f64 / compared as f64
            };
            matrix.set(row, col, distance);
        }
    }
    matrix
}
/// Construction of an alignment from bare sequence strings that carry no headers.
pub trait FastaReader {
    /// Wraps each sequence in a record with a generated header.
    fn from_unnamed_sequences(sequences: Vec<String>) -> Self;
}

impl FastaReader for MSA {
    /// Headers are generated as `Seq0`, `Seq1`, … following the input order.
    fn from_unnamed_sequences(sequences: Vec<String>) -> Self {
        (0usize..)
            .zip(sequences)
            .map(|(ordinal, sequence)| FastaSequence {
                header: format!("Seq{}", ordinal),
                sequence,
            })
            .collect()
    }
}
/// Builds a neighbor-joining tree from `conf.msa` and renders it as a `;`-terminated Newick
/// string.
///
/// Pairwise distances come from `dist_from_msa`. `conf.hide_internal` is passed straight
/// through to `NJNode::to_newick`.
///
/// # Errors
///
/// Returns the neighbor-joining error message, prefixed with `neighbor-joining failed: `.
pub fn nj(conf: NJConfig) -> Result<String, String> {
    let matrix = dist_from_msa(&conf.msa);
    match neighbor_joining(matrix) {
        Ok(tree) => Ok(format!("{};", tree.to_newick(conf.hide_internal))),
        Err(e) => Err(format!("neighbor-joining failed: {e}")),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for floating-point comparisons in these tests.
    const TOL: f64 = 1e-12;

    /// True when `actual` lies strictly within `TOL` of `expected`.
    fn approx(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < TOL
    }

    /// Leaf names of `node`, with the left subtree listed before the right subtree.
    fn collect_leaf_names(node: &NJNode) -> Vec<String> {
        match (&node.left, &node.right) {
            (None, None) => vec![node.name.clone()],
            (Some(lhs), Some(rhs)) => {
                let mut names = collect_leaf_names(lhs);
                names.extend(collect_leaf_names(rhs));
                names
            }
            _ => unreachable!("panic: NJNode should be either leaf or internal"),
        }
    }

    #[test]
    fn test_trimatrix_set_get_and_dim() {
        let names: Vec<String> = ["A", "B", "C"].iter().map(|s| s.to_string()).collect();
        let mut m = TriMat::with_names(names);
        assert_eq!(m.dim(), 3);
        assert_eq!(m.get(0, 1), 0.0);
        m.set(0, 1, 0.25);
        m.set(2, 1, 0.75);
        for (i, j, want) in [(0, 1, 0.25), (1, 0, 0.25), (2, 1, 0.75), (1, 2, 0.75)] {
            assert!(approx(m.get(i, j), want));
        }
        assert_eq!(m.get(1, 1), 0.0);
    }

    #[test]
    fn test_tri_from_msa_basic() {
        let msa = MSA::from_unnamed_sequences(vec!["ACG".into(), "ATG".into(), "A-G".into()]);
        let mat = dist_from_msa(&msa);
        assert_eq!(mat.node_names, vec!["Seq0", "Seq1", "Seq2"]);
        assert!(approx(mat.get(0, 1), 1.0 / 3.0));
        assert!(approx(mat.get(0, 2), 0.0));
        assert!(approx(mat.get(1, 2), 0.0));
    }

    #[test]
    fn test_tri_from_msa_no_overlap() {
        let record = |header: &str, sequence: &str| FastaSequence {
            header: header.to_string(),
            sequence: sequence.to_string(),
        };
        let msa: MSA = vec![record("X", "A--"), record("Y", "--A")];
        let mat = dist_from_msa(&msa);
        assert_eq!(mat.node_names, vec!["X", "Y"]);
        assert!(approx(mat.get(0, 1), 0.0));
    }

    #[test]
    fn test_neighbor_joining_two_taxa() {
        let mut m = TriMat::with_names(vec!["A".into(), "B".into()]);
        m.set(0, 1, 0.6);
        let tree = neighbor_joining(m).expect("NJ should succeed for two taxa");
        let mut leaves = collect_leaf_names(&tree);
        leaves.sort();
        assert_eq!(leaves, vec!["A".to_string(), "B".to_string()]);
        assert!(approx(tree.left_len, 0.3));
        assert!(approx(tree.right_len, 0.3));
    }

    #[test]
    fn test_neighbor_joining_three_taxa_preserves_leaves() {
        let mut m = TriMat::with_names(vec!["A".into(), "B".into(), "C".into()]);
        for (i, j) in [(0, 1), (0, 2), (1, 2)] {
            m.set(i, j, 0.2);
        }
        let tree = neighbor_joining(m).expect("NJ should succeed for three taxa");
        let mut leaves = collect_leaf_names(&tree);
        leaves.sort();
        assert_eq!(
            leaves,
            vec!["A".to_string(), "B".to_string(), "C".to_string()]
        );

        /// Recursively asserts that no branch length is (meaningfully) negative.
        fn check_nonneg(node: &NJNode) {
            assert!(node.left_len >= -TOL, "left_len negative");
            assert!(node.right_len >= -TOL, "right_len negative");
            if let (Some(l), Some(r)) = (&node.left, &node.right) {
                check_nonneg(l);
                check_nonneg(r);
            }
        }
        check_nonneg(&tree);
    }

    #[test]
    fn test_to_newick_produces_valid_format() {
        let root = NJNode::internal(
            "root".into(),
            NJNode::leaf("L".into()),
            NJNode::leaf("R".into()),
            0.1234,
            0.5678,
        );
        let hidden = root.to_newick(true);
        assert!(hidden.contains(":0.123"));
        assert!(hidden.contains(":0.568"));
        let named = root.to_newick(false);
        assert!(named.ends_with("root"));
    }
}