use crate::MSA;
use crate::alphabet::AlphabetEncoding;
use crate::models::{ModelCalculation, pairwise_distance};
use crate::nj::NJState;
use crate::tree::TreeNode;
use serde::{Deserialize, Serialize};
/// Symmetric pairwise distance matrix over a set of named taxa.
///
/// Only the strict lower triangle is stored, packed row by row, so a matrix
/// over `n` names holds `n * (n - 1) / 2` values. The diagonal is implicitly
/// zero and is never stored.
#[derive(Clone, Debug)]
pub struct DistMat {
/// Packed lower-triangle distances: entry `(i, j)` with `i > j` lives at
/// index `i * (i - 1) / 2 + j`. Code that edits this field directly must keep
/// `data.len() == names.len() * (names.len() - 1) / 2`, otherwise lookups
/// panic or read the wrong cell.
pub data: Vec<f64>,
/// Taxon labels in row/column order; `names.len()` is the matrix dimension.
pub names: Vec<String>,
}
impl DistMat {
    /// Creates an all-zero distance matrix over the given taxa.
    ///
    /// Only the strict lower triangle is allocated (`n * (n - 1) / 2` entries
    /// for `n` names). The diagonal is implicitly zero.
    ///
    /// # Panics
    ///
    /// Panics if `n * (n - 1)` overflows `usize`.
    pub fn empty_with_names(names: Vec<String>) -> Self {
        let n = names.len();
        // saturating_sub keeps n == 0 from underflowing; that case gives 0 entries.
        let n_entries = n
            .checked_mul(n.saturating_sub(1))
            .expect("distance matrix too large: n * (n-1) overflows usize")
            / 2;
        Self {
            data: vec![0.0; n_entries],
            names,
        }
    }

    /// Returns the number of taxa, i.e. the number of rows and columns.
    pub fn dim(&self) -> usize {
        self.names.len()
    }

    /// Maps an off-diagonal pair to its offset in the packed lower triangle.
    ///
    /// The larger index is always used as the row, which makes the lookup
    /// symmetric. Callers must ensure `i != j`:
    /// - for `i == j == 0` the `row - 1` below would underflow;
    /// - for `i == j > 0` the result would alias a different cell.
    fn idx(&self, i: usize, j: usize) -> usize {
        let (row, col) = if i > j { (i, j) } else { (j, i) };
        row * (row - 1) / 2 + col
    }

    /// Returns the distance between taxa `i` and `j`.
    ///
    /// The matrix is symmetric (`get(i, j) == get(j, i)`) and the diagonal is
    /// always `0.0`.
    ///
    /// # Panics
    ///
    /// Panics if `i != j` and either index is `>= self.dim()`.
    pub fn get(&self, i: usize, j: usize) -> f64 {
        if i == j {
            0.0
        } else {
            self.data[self.idx(i, j)]
        }
    }

    /// Sets the (symmetric) distance between taxa `i` and `j`.
    ///
    /// Writes to the diagonal (`i == j`) are ignored, because the diagonal is
    /// fixed at `0.0`.
    ///
    /// # Panics
    ///
    /// Panics if `i != j` and either index is `>= self.dim()`.
    pub fn set(&mut self, i: usize, j: usize, val: f64) {
        if i != j {
            let index = self.idx(i, j);
            self.data[index] = val;
        }
    }

    /// Computes every pairwise distance between the sequences of `msa` using
    /// the distance model `M`.
    ///
    /// Rows and columns are labelled with `msa.identifiers`. Each unordered
    /// pair `(i, j)` with `j < i` is evaluated exactly once via
    /// [`pairwise_distance`]. With the `parallel` feature the pairs are
    /// evaluated on rayon's thread pool; otherwise a plain nested loop is used.
    ///
    /// NOTE(review): assumes `msa.identifiers.len() == msa.n_sequences`. A
    /// sequence count larger than the identifier count makes `set` panic with
    /// an out-of-bounds index. Confirm this is guaranteed by `MSA`.
    pub fn from_msa<M, A>(msa: &MSA<A>) -> DistMat
    where
        M: ModelCalculation<A> + Send + Sync,
        A: AlphabetEncoding + Sync,
        A::Symbol: Send + Sync,
    {
        let n = msa.n_sequences;
        let mut dist = DistMat::empty_with_names(msa.identifiers.clone());
        #[cfg(feature = "parallel")]
        {
            use rayon::iter::{IntoParallelIterator, ParallelIterator};
            let pairs: Vec<(usize, usize, f64)> = (0..n)
                .flat_map(|i| (0..i).map(move |j| (i, j)))
                .collect::<Vec<_>>()
                .into_par_iter()
                .map(|(i, j)| {
                    let d = pairwise_distance::<M, A>(&msa.sequences[i], &msa.sequences[j]);
                    (i, j, d)
                })
                .collect();
            // Distances are computed in parallel but written back sequentially,
            // so `dist` never needs to be shared mutably across threads.
            for (i, j, d) in pairs {
                dist.set(i, j, d);
            }
        }
        #[cfg(not(feature = "parallel"))]
        {
            for i in 0..n {
                for j in 0..i {
                    dist.set(i, j, pairwise_distance::<M, A>(&msa.sequences[i], &msa.sequences[j]));
                }
            }
        }
        dist
    }

    /// Expands the packed triangle into a dense, symmetric `dim x dim` matrix
    /// with a zero diagonal.
    ///
    /// Shared by [`DistMat::into_result`] and [`DistMat::to_result`] so the
    /// two conversions cannot drift apart.
    fn square_matrix(&self) -> Vec<Vec<f64>> {
        let n = self.dim();
        (0..n)
            .map(|i| (0..n).map(|j| self.get(i, j)).collect())
            .collect()
    }

    /// Converts the matrix into a dense [`DistanceResult`].
    ///
    /// Consumes `self` so the taxon names are moved rather than cloned.
    pub fn into_result(self) -> DistanceResult {
        let matrix = self.square_matrix();
        DistanceResult {
            names: self.names,
            matrix,
        }
    }

    /// Builds a dense [`DistanceResult`] without consuming `self`. The taxon
    /// names are cloned.
    pub fn to_result(&self) -> DistanceResult {
        DistanceResult {
            names: self.names.clone(),
            matrix: self.square_matrix(),
        }
    }

    /// Returns the mean of all stored off-diagonal distances.
    ///
    /// Each unordered pair is counted once. Returns `0.0` when there are fewer
    /// than two taxa, because then no pairs exist.
    pub fn average(&self) -> f64 {
        if self.data.is_empty() {
            return 0.0;
        }
        self.data.iter().sum::<f64>() / self.data.len() as f64
    }

    /// Builds a neighbor-joining tree from this matrix.
    ///
    /// The matrix is consumed because [`NJState`] takes a mutable borrow of it.
    ///
    /// # Errors
    ///
    /// Propagates the error message from [`NJState`], for example
    /// `"Empty distance matrix"` when there are no taxa.
    pub fn neighbor_joining(mut self) -> Result<TreeNode, String> {
        NJState::new(&mut self).run()
    }
}
/// Dense, serializable form of a distance matrix.
///
/// This is what `DistMat::to_result` / `DistMat::into_result` return. The
/// `ts(export, ...)` attribute also exports it as a TypeScript type.
#[derive(Clone, Debug, PartialEq, ts_rs::TS, Serialize, Deserialize)]
#[ts(export, export_to = "../../wasm/types/lib_types.ts")]
pub struct DistanceResult {
/// Taxon labels in row/column order.
pub names: Vec<String>,
/// Row-major distances: `matrix[i][j]` is the distance between `names[i]`
/// and `names[j]`. When produced by `DistMat`, the matrix is square,
/// symmetric, and has a zero diagonal.
pub matrix: Vec<Vec<f64>>,
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::alphabet::DNA;
    use crate::models::PDiff;
    use crate::msa::MSA;
    use crate::tree::NameOrSupport;

    /// Absolute tolerance for floating-point comparisons in these tests.
    const EPS: f64 = 1e-12;

    /// Asserts that `actual` lies strictly within `EPS` of `expected`.
    /// A NaN `actual` fails.
    #[track_caller]
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < EPS,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    /// Collects leaf labels left to right. Leaves without a name label fall
    /// back to their identifier.
    fn collect_leaf_names(node: &TreeNode) -> Vec<String> {
        match &node.children {
            Some([left, right]) => {
                let mut names = collect_leaf_names(left);
                names.extend(collect_leaf_names(right));
                names
            }
            None => match node.label {
                Some(NameOrSupport::Name(ref name)) => vec![name.clone()],
                _ => vec![node.identifier.to_string()],
            },
        }
    }

    /// Recursively asserts that every child branch length is non-negative,
    /// allowing `-EPS` of rounding slack.
    fn assert_nonneg_branch_lengths(node: &TreeNode) {
        if let Some(children) = &node.children {
            assert!(children[0].len.unwrap() >= -EPS, "left_len negative");
            assert!(children[1].len.unwrap() >= -EPS, "right_len negative");
            assert_nonneg_branch_lengths(children[0].as_ref());
            assert_nonneg_branch_lengths(children[1].as_ref());
        }
    }

    /// Three taxa with distinct distances: (0,1)=0.25, (0,2)=0.5, (1,2)=0.75.
    fn three_taxa_matrix() -> DistMat {
        let mut m = DistMat::empty_with_names(vec!["A".into(), "B".into(), "C".into()]);
        m.set(0, 1, 0.25);
        m.set(0, 2, 0.5);
        m.set(1, 2, 0.75);
        m
    }

    #[test]
    fn test_dist_set_get_and_dim() {
        let mut m = DistMat::empty_with_names(vec!["A".into(), "B".into(), "C".into()]);
        assert_eq!(m.dim(), 3);
        assert_eq!(m.get(0, 1), 0.0);
        m.set(0, 1, 0.25);
        m.set(2, 1, 0.75);
        assert_close(m.get(0, 1), 0.25);
        assert_close(m.get(1, 0), 0.25);
        assert_close(m.get(2, 1), 0.75);
        assert_close(m.get(1, 2), 0.75);
        assert_eq!(m.get(1, 1), 0.0);
    }

    #[test]
    fn test_dist_empty_with_names_zero_dim() {
        let m = DistMat::empty_with_names(vec![]);
        assert_eq!(m.dim(), 0);
        assert!(m.data.is_empty());
    }

    #[test]
    fn test_dist_from_msa_basic() {
        let seqs: Vec<String> = vec!["ACG".into(), "ATG".into(), "A-G".into()];
        let msa = MSA::<DNA>::from_unnamed_sequences(seqs).unwrap();
        let mat = msa.into_dist::<PDiff>();
        assert_eq!(mat.names, vec!["Seq0", "Seq1", "Seq2"]);
        assert_close(mat.get(0, 1), 1.0 / 3.0);
        assert_close(mat.get(0, 2), 0.0);
        assert_close(mat.get(1, 2), 0.0);
    }

    #[test]
    fn test_dist_from_msa_gap_positions_excluded_from_denominator() {
        let seqs: Vec<String> = vec!["AT-".into(), "ACG".into()];
        let msa = MSA::<DNA>::from_unnamed_sequences(seqs).unwrap();
        let mat = msa.into_dist::<PDiff>();
        assert_close(mat.get(0, 1), 1.0 / 2.0);
    }

    #[test]
    fn test_dist_from_msa_no_overlap() {
        let seqs = vec![("X".into(), "A--".into()), ("Y".into(), "--A".into())];
        let msa = MSA::<DNA>::from_iter(seqs.into_iter());
        let mat = msa.into_dist::<PDiff>();
        assert_eq!(mat.names, vec!["X", "Y"]);
        assert_close(mat.get(0, 1), 0.0);
    }

    #[test]
    fn test_dist_symmetry_invariant() {
        let mut m = DistMat::empty_with_names(vec!["A".into(), "B".into(), "C".into()]);
        m.set(2, 0, 0.25);
        assert_eq!(m.get(0, 2), 0.25);
        assert_eq!(m.get(2, 0), 0.25);
    }

    #[test]
    fn test_average_of_off_diagonal_entries() {
        // (0.25 + 0.5 + 0.75) / 3 pairs
        assert_close(three_taxa_matrix().average(), 0.5);
    }

    #[test]
    fn test_average_without_pairs_is_zero() {
        assert_eq!(DistMat::empty_with_names(vec![]).average(), 0.0);
        assert_eq!(DistMat::empty_with_names(vec!["A".into()]).average(), 0.0);
    }

    #[test]
    fn test_to_result_expands_symmetric_matrix() {
        let result = three_taxa_matrix().to_result();
        assert_eq!(result.names, vec!["A", "B", "C"]);
        assert_eq!(
            result.matrix,
            vec![
                vec![0.0, 0.25, 0.5],
                vec![0.25, 0.0, 0.75],
                vec![0.5, 0.75, 0.0],
            ]
        );
    }

    #[test]
    fn test_into_result_matches_to_result() {
        let m = three_taxa_matrix();
        let borrowed = m.to_result();
        assert_eq!(m.into_result(), borrowed);
    }

    #[test]
    fn test_to_result_empty_matrix() {
        let result = DistMat::empty_with_names(vec![]).to_result();
        assert!(result.names.is_empty());
        assert!(result.matrix.is_empty());
    }

    #[test]
    fn test_neighbor_joining_one_taxon() {
        let m = DistMat::empty_with_names(vec!["A".into()]);
        let tree = m
            .neighbor_joining()
            .expect("NJ should succeed for one taxon");
        assert!(tree.children.is_none());
        assert!(matches!(tree.label, Some(NameOrSupport::Name(ref s)) if s == "A"));
    }

    #[test]
    fn test_neighbor_joining_empty_matrix_returns_error() {
        let m = DistMat::empty_with_names(vec![]);
        let err = m.neighbor_joining().unwrap_err();
        assert_eq!(err, "Empty distance matrix");
    }

    #[test]
    fn test_neighbor_joining_two_taxa() {
        let mut m = DistMat::empty_with_names(vec!["A".into(), "B".into()]);
        m.set(0, 1, 0.6);
        let tree = m
            .neighbor_joining()
            .expect("NJ should succeed for two taxa");
        let mut leaves = collect_leaf_names(&tree);
        leaves.sort();
        assert_eq!(leaves, vec!["A".to_string(), "B".to_string()]);
        let children = tree
            .children
            .as_ref()
            .expect("two-taxon tree should have children");
        assert_close(children[0].len.unwrap(), 0.3);
        assert_close(children[1].len.unwrap(), 0.3);
    }

    #[test]
    fn test_neighbor_joining_three_taxa_preserves_leaves() {
        let mut m = DistMat::empty_with_names(vec!["A".into(), "B".into(), "C".into()]);
        m.set(0, 1, 0.2);
        m.set(0, 2, 0.2);
        m.set(1, 2, 0.2);
        let tree = m
            .neighbor_joining()
            .expect("NJ should succeed for three taxa");
        let mut leaves = collect_leaf_names(&tree);
        leaves.sort();
        assert_eq!(
            leaves,
            vec!["A".to_string(), "B".to_string(), "C".to_string()]
        );
        assert_nonneg_branch_lengths(&tree);
    }

    #[test]
    fn test_to_newick_produces_valid_format() {
        let left = TreeNode::leaf(0, "L".into(), Some(0.1234));
        let right = TreeNode::leaf(1, "R".into(), Some(0.5678));
        let root = TreeNode::internal(
            2,
            Some([Box::new(left), Box::new(right)]),
            Some(0.0),
            Some(85),
        );
        let newick = root.to_newick();
        assert!(newick.contains(":0.123"));
        assert!(newick.contains(":0.568"));
        assert!(newick.starts_with("("));
        assert!(newick.ends_with(")85;"));
    }

    #[test]
    fn test_two_taxa_branch_length_sum() {
        let mut m = DistMat::empty_with_names(vec!["A".into(), "B".into()]);
        m.set(0, 1, 1.0);
        let tree = m.neighbor_joining().unwrap();
        let children = tree.children.unwrap();
        assert_close(children[0].len.unwrap() + children[1].len.unwrap(), 1.0);
    }
}