pub mod alphabet;
pub mod config;
pub mod distance_matrix;
pub mod models;
pub mod msa;
pub mod nj;
pub mod tree;
use bitvec::prelude::{BitVec, Lsb0, bitvec};
use std::collections::HashMap;
use crate::alphabet::{Alphabet, AlphabetEncoding, DNA, Protein};
use crate::config::SubstitutionModel;
pub use crate::config::{DistConfig, MSA, NJConfig, SequenceObject};
use crate::distance_matrix::DistMat;
pub use crate::distance_matrix::DistanceResult;
use crate::models::{JukesCantor, Kimura2P, ModelCalculation, PDiff, Poisson};
use crate::tree::{NameOrSupport, TreeNode};
/// Sets, in `out`, the bit of every leaf at or below `node`.
///
/// A leaf is identified by its `NameOrSupport::Name` label, which `idx` maps to a bit
/// position. Internal nodes are traversed recursively through both children. Bits are
/// only ever set, never cleared, so callers pass a zero-initialised vector.
///
/// # Errors
/// Returns an error if a leaf has no name label, if a leaf name is not present in `idx`,
/// or if the mapped position does not fit in `out`.
fn bitset_of(
    node: &TreeNode,
    idx: &HashMap<String, usize>,
    out: &mut BitVec<u8, Lsb0>,
) -> Result<(), String> {
    match &node.children {
        None => match &node.label {
            Some(NameOrSupport::Name(name)) => {
                // Checked lookup instead of `idx[name]`: a tree/alignment mismatch should
                // surface as an error through this `Result`, not as a panic.
                let i = *idx
                    .get(name)
                    .ok_or_else(|| format!("Leaf '{}' is not a taxon of the alignment", name))?;
                // `BitVec::set` panics on out-of-range indices; report it instead.
                if i >= out.len() {
                    return Err(format!(
                        "Taxon index {} for leaf '{}' is out of range for {} taxa",
                        i,
                        name,
                        out.len()
                    ));
                }
                out.set(i, true);
                Ok(())
            }
            _ => Err("Leaf node without a name label".into()),
        },
        Some([l, r]) => {
            bitset_of(l, idx, out)?;
            bitset_of(r, idx, out)
        }
    }
}
/// Adds one to `counter` for every non-trivial clade found in `tree`.
///
/// For each internal node, the leaves beneath it are encoded as an `n_taxa`-bit vector
/// (positions taken from `idx`), and the raw bytes of that vector serve as the map key.
/// Clades containing a single taxon or all `n_taxa` taxa are not counted. Leaves
/// contribute nothing.
///
/// # Errors
/// Propagates any error raised by `bitset_of` while encoding a clade.
fn count_clades(
    tree: &TreeNode,
    idx: &HashMap<String, usize>,
    n_taxa: usize,
    counter: &mut HashMap<Vec<u8>, usize>,
) -> Result<(), String> {
    let (left, right) = match &tree.children {
        Some([left, right]) => (left, right),
        None => return Ok(()),
    };

    let mut members = bitvec![u8, Lsb0; 0; n_taxa];
    bitset_of(tree, idx, &mut members)?;

    // Skip trivial clades: singletons and the full taxon set.
    let size = members.count_ones();
    if (2..n_taxa).contains(&size) {
        *counter.entry(members.as_raw_slice().to_vec()).or_insert(0) += 1;
    }

    count_clades(left, idx, n_taxa, counter)?;
    count_clades(right, idx, n_taxa, counter)
}
/// Runs `n_bootstrap_samples` bootstrap replicates of `msa` and counts how often each
/// non-trivial clade appears across the replicate neighbor-joining trees.
///
/// Returns `Ok(None)` when `n_bootstrap_samples == 0`. Otherwise it returns a map. Each
/// key is a clade's bitset bytes, with taxon positions taken from `msa.to_index_map()`,
/// and each value is the number of replicate trees that contain that clade. When
/// `on_progress` is given, it is called after every replicate with `(completed, total)`.
///
/// # Errors
/// Propagates errors from resampling (`bootstrap`) and from clade extraction. If a
/// replicate's neighbor joining fails, the error carries the 1-based replicate number.
fn bootstrap_clade_counts<A: AlphabetEncoding, M: ModelCalculation<A>>(
    msa: &MSA<A>,
    n_bootstrap_samples: usize,
    on_progress: Option<&dyn Fn(usize, usize)>,
) -> Result<Option<HashMap<Vec<u8>, usize>>, String> {
    if n_bootstrap_samples == 0 {
        return Ok(None);
    }
    let idx_map: HashMap<String, usize> = msa.to_index_map();
    let mut counter = HashMap::new();
    for i in 0..n_bootstrap_samples {
        let replicate = i + 1;
        // Library code: a failed replicate is reported to the caller instead of panicking.
        // `String: From<E>` holds for this error type because `run_nj` already applies `?`
        // to the same `neighbor_joining()` call in a `Result<_, String>` context.
        let tree = msa
            .bootstrap()?
            .into_dist::<M>()
            .neighbor_joining()
            .map_err(|e| {
                format!(
                    "NJ bootstrap iteration {} failed: {}",
                    replicate,
                    String::from(e)
                )
            })?;
        count_clades(&tree, &idx_map, msa.n_sequences, &mut counter)?;
        if let Some(cb) = on_progress {
            cb(replicate, n_bootstrap_samples);
        }
    }
    Ok(Some(counter))
}
/// Labels the internal nodes of `node` with their bootstrap clade counts.
///
/// For each internal node, `bitset_of` encodes the leaves beneath it, and the result is
/// looked up in `counts`, which is keyed the same way as `count_clades` output. A node
/// whose clade was seen gets its label replaced by `NameOrSupport::Support(count)`.
/// Clades that were never seen keep their existing label. Singleton and all-taxa clades
/// are never labelled, and leaves are left unchanged.
///
/// # Errors
/// Propagates any error raised by `bitset_of` while encoding a clade.
fn add_bootstrap_to_tree(
    node: &mut TreeNode,
    idx: &HashMap<String, usize>,
    n_taxa: usize,
    counts: &HashMap<Vec<u8>, usize>,
) -> Result<(), String> {
    if node.children.is_none() {
        return Ok(());
    }

    let mut bv = bitvec![u8, Lsb0; 0; n_taxa];
    bitset_of(node, idx, &mut bv)?;
    let n = bv.count_ones();
    if n > 1 && n < n_taxa {
        // `Vec<u8>: Borrow<[u8]>`, so the raw bytes work as a lookup key directly,
        // without allocating a temporary `Vec` for every internal node.
        if let Some(&c) = counts.get(bv.as_raw_slice()) {
            node.label = Some(NameOrSupport::Support(c));
        }
    }

    if let Some([l, r]) = &mut node.children {
        add_bootstrap_to_tree(l, idx, n_taxa, counts)?;
        add_bootstrap_to_tree(r, idx, n_taxa, counts)?;
    }
    Ok(())
}
/// Checks that `msa` is a usable alignment before any distances or trees are computed.
///
/// # Errors
/// Returns a descriptive message when:
/// * the alignment contains no sequences;
/// * the first sequence is empty (the length check then forces all to be empty);
/// * any sequence's length (in bytes) differs from the first sequence's;
/// * two sequences share an identifier. The bootstrap code indexes taxa by name through a
///   `HashMap<String, usize>`, so duplicates would silently collapse onto one taxon.
fn validate_msa(msa: &[SequenceObject]) -> Result<(), String> {
    let first = msa
        .first()
        .ok_or_else(|| String::from("Input MSA is empty"))?;
    let expected_len = first.sequence.len();
    if expected_len == 0 {
        return Err("Sequences must not be empty".into());
    }
    let mut seen_ids = std::collections::HashSet::with_capacity(msa.len());
    for s in msa {
        if s.sequence.len() != expected_len {
            return Err(format!(
                "All sequences must have the same length. Expected {}, got {} for '{}'",
                expected_len,
                s.sequence.len(),
                s.identifier
            ));
        }
        if !seen_ids.insert(s.identifier.as_str()) {
            return Err(format!("Duplicate sequence identifier '{}'", s.identifier));
        }
    }
    Ok(())
}
/// Guesses whether the alignment holds nucleotide or amino-acid sequences.
///
/// Each byte is upper-cased and checked against the nucleotide symbol set
/// `A C G T U N -`. A single byte outside that set classifies the whole alignment as
/// `Alphabet::Protein`. Otherwise, including for an empty slice, the result is
/// `Alphabet::DNA`. Currently always returns `Ok`.
fn detect_alphabet(msa: &[SequenceObject]) -> Result<Alphabet, String> {
    let is_nucleotide_symbol = |byte: u8| {
        matches!(
            byte.to_ascii_uppercase(),
            b'A' | b'C' | b'G' | b'T' | b'U' | b'N' | b'-'
        )
    };
    let has_non_nucleotide = msa
        .iter()
        .any(|record| !record.sequence.bytes().all(is_nucleotide_symbol));
    let alphabet = if has_non_nucleotide {
        Alphabet::Protein
    } else {
        Alphabet::DNA
    };
    Ok(alphabet)
}
/// Computes the pairwise distance matrix of `msa` under substitution model `M` and
/// returns it as a [`DistanceResult`].
///
/// Currently always returns `Ok`; the `Result` matches the other `run_*` helpers.
fn run_distance_matrix<A, M>(msa: MSA<A>) -> Result<DistanceResult, String>
where
    A: AlphabetEncoding,
    M: ModelCalculation<A>,
{
    let distances = msa.into_dist::<M>();
    let result = distances.into_result();
    Ok(result)
}
/// Returns the average pairwise distance of `msa` under substitution model `M`.
///
/// Currently always returns `Ok`; the `Result` matches the other `run_*` helpers.
fn run_average_distance<A, M>(msa: MSA<A>) -> Result<f64, String>
where
    A: AlphabetEncoding,
    M: ModelCalculation<A>,
{
    let mean_distance = msa.into_dist::<M>().average();
    Ok(mean_distance)
}
/// Builds the neighbor-joining tree of `msa` under model `M` and renders it as Newick.
///
/// When `n_bootstrap_samples > 0`, that many bootstrap replicates are resampled from
/// `msa` and their clades are counted. Each non-trivial clade of the main tree that
/// occurred in at least one replicate is then labelled with its replicate count.
/// `on_progress` is forwarded to the bootstrap loop.
///
/// # Errors
/// Propagates errors from building the main tree, from bootstrapping, and from annotating
/// the tree with support values.
fn run_nj<A, M>(
    msa: MSA<A>,
    n_bootstrap_samples: usize,
    on_progress: Option<&dyn Fn(usize, usize)>,
) -> Result<String, String>
where
    A: AlphabetEncoding,
    M: ModelCalculation<A>,
{
    // Build the main tree first. It is a single NJ run, so a failure here is reported
    // before paying for `n_bootstrap_samples` replicate NJ runs.
    let mut main_tree = msa.into_dist::<M>().neighbor_joining()?;
    let clade_counts = bootstrap_clade_counts::<A, M>(&msa, n_bootstrap_samples, on_progress)?;
    if let Some(counts) = clade_counts {
        let main_idx_map: HashMap<String, usize> = msa.to_index_map();
        add_bootstrap_to_tree(&mut main_tree, &main_idx_map, msa.n_sequences, &counts)?;
    }
    Ok(main_tree.to_newick())
}
/// Builds a neighbor-joining tree from `conf.msa` and returns it as a Newick string.
///
/// The alignment is validated and its alphabet is auto-detected. `conf.substitution_model`
/// must suit that alphabet: DNA accepts `PDiff`, `JukesCantor` and `Kimura2P`; protein
/// accepts `Poisson` and `PDiff`. When `conf.n_bootstrap_samples > 0`, internal nodes are
/// annotated with bootstrap clade counts. In that case `on_progress`, if provided, receives
/// `(completed, total)` after each replicate.
///
/// # Errors
/// Returns an error if validation fails, if the model does not fit the detected alphabet,
/// or if tree construction or bootstrapping fails.
pub fn nj(
    conf: NJConfig,
    on_progress: Option<Box<dyn Fn(usize, usize)>>,
) -> Result<String, String> {
    let progress = on_progress.as_deref();
    validate_msa(&conf.msa)?;
    let alphabet = detect_alphabet(&conf.msa)?;

    let replicates = conf.n_bootstrap_samples;
    let model = conf.substitution_model;
    let records = conf.msa.into_iter().map(|s| (s.identifier, s.sequence));

    match alphabet {
        Alphabet::DNA => {
            let dna = MSA::<DNA>::from_iter(records);
            match model {
                SubstitutionModel::PDiff => run_nj::<DNA, PDiff>(dna, replicates, progress),
                SubstitutionModel::JukesCantor => {
                    run_nj::<DNA, JukesCantor>(dna, replicates, progress)
                }
                SubstitutionModel::Kimura2P => {
                    run_nj::<DNA, Kimura2P>(dna, replicates, progress)
                }
                SubstitutionModel::Poisson => {
                    Err("Poisson is a protein model; cannot use with DNA".into())
                }
            }
        }
        Alphabet::Protein => {
            let protein = MSA::<Protein>::from_iter(records);
            match model {
                SubstitutionModel::Poisson => {
                    run_nj::<Protein, Poisson>(protein, replicates, progress)
                }
                SubstitutionModel::PDiff => {
                    run_nj::<Protein, PDiff>(protein, replicates, progress)
                }
                SubstitutionModel::JukesCantor | SubstitutionModel::Kimura2P => {
                    Err("Selected model is for DNA; cannot use with Protein".into())
                }
            }
        }
    }
}
/// Computes the full pairwise distance matrix for the sequences in `conf.msa`.
///
/// The alignment is validated and its alphabet is auto-detected. Distances use
/// `conf.substitution_model`: DNA accepts `PDiff`, `JukesCantor` and `Kimura2P`; protein
/// accepts `Poisson` and `PDiff`.
///
/// # Errors
/// Returns an error if validation fails or if the model does not fit the detected alphabet.
pub fn distance_matrix(conf: DistConfig) -> Result<DistanceResult, String> {
    validate_msa(&conf.msa)?;
    let alphabet = detect_alphabet(&conf.msa)?;

    let model = conf.substitution_model;
    let records = conf.msa.into_iter().map(|s| (s.identifier, s.sequence));

    match alphabet {
        Alphabet::DNA => {
            let dna = MSA::<DNA>::from_iter(records);
            match model {
                SubstitutionModel::PDiff => run_distance_matrix::<DNA, PDiff>(dna),
                SubstitutionModel::JukesCantor => run_distance_matrix::<DNA, JukesCantor>(dna),
                SubstitutionModel::Kimura2P => run_distance_matrix::<DNA, Kimura2P>(dna),
                SubstitutionModel::Poisson => {
                    Err("Poisson is a protein model; cannot use with DNA".into())
                }
            }
        }
        Alphabet::Protein => {
            let protein = MSA::<Protein>::from_iter(records);
            match model {
                SubstitutionModel::Poisson => run_distance_matrix::<Protein, Poisson>(protein),
                SubstitutionModel::PDiff => run_distance_matrix::<Protein, PDiff>(protein),
                SubstitutionModel::JukesCantor | SubstitutionModel::Kimura2P => {
                    Err("Selected model is for DNA; cannot use with Protein".into())
                }
            }
        }
    }
}
/// Computes the average pairwise distance of the sequences in `conf.msa`.
///
/// The alignment is validated and its alphabet is auto-detected. Distances use
/// `conf.substitution_model`: DNA accepts `PDiff`, `JukesCantor` and `Kimura2P`; protein
/// accepts `Poisson` and `PDiff`.
///
/// # Errors
/// Returns an error if validation fails or if the model does not fit the detected alphabet.
pub fn average_distance(conf: DistConfig) -> Result<f64, String> {
    validate_msa(&conf.msa)?;
    let detected = detect_alphabet(&conf.msa)?;

    let chosen_model = conf.substitution_model;
    let pairs = conf.msa.into_iter().map(|s| (s.identifier, s.sequence));

    match detected {
        Alphabet::DNA => {
            let nucleotides = MSA::<DNA>::from_iter(pairs);
            match chosen_model {
                SubstitutionModel::PDiff => run_average_distance::<DNA, PDiff>(nucleotides),
                SubstitutionModel::JukesCantor => {
                    run_average_distance::<DNA, JukesCantor>(nucleotides)
                }
                SubstitutionModel::Kimura2P => {
                    run_average_distance::<DNA, Kimura2P>(nucleotides)
                }
                SubstitutionModel::Poisson => {
                    Err("Poisson is a protein model; cannot use with DNA".into())
                }
            }
        }
        Alphabet::Protein => {
            let residues = MSA::<Protein>::from_iter(pairs);
            match chosen_model {
                SubstitutionModel::Poisson => run_average_distance::<Protein, Poisson>(residues),
                SubstitutionModel::PDiff => run_average_distance::<Protein, PDiff>(residues),
                SubstitutionModel::JukesCantor | SubstitutionModel::Kimura2P => {
                    Err("Selected model is for DNA; cannot use with Protein".into())
                }
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    // Use the same path the production code matches against (`crate::config`).
    // `DistConfig` already arrives through `super::*` via the crate-root re-export.
    use crate::config::SubstitutionModel;
    use std::cell::RefCell;
    use std::rc::Rc;

    #[test]
    fn test_nj_wrapper_simple_tree() {
        let sequences = vec![
            SequenceObject {
                identifier: "A".into(),
                sequence: "ACGTCG".into(),
            },
            SequenceObject {
                identifier: "B".into(),
                sequence: "ACG-GC".into(),
            },
        ];
        let conf = NJConfig {
            msa: sequences,
            n_bootstrap_samples: 0,
            substitution_model: SubstitutionModel::PDiff,
        };
        let newick = nj(conf, None).expect("NJ failed");
        assert_eq!(newick, "(A:0.167,B:0.167);");
    }

    #[test]
    fn test_nj_wrapper_adds_semicolon() {
        let sequences = vec![
            SequenceObject {
                identifier: "Seq0".into(),
                sequence: "A".into(),
            },
            SequenceObject {
                identifier: "Seq1".into(),
                sequence: "A".into(),
            },
        ];
        let conf = NJConfig {
            msa: sequences,
            n_bootstrap_samples: 0,
            substitution_model: SubstitutionModel::PDiff,
        };
        let out = nj(conf, None).unwrap();
        assert!(out.ends_with(';'));
    }

    #[test]
    fn test_nj_deterministic_order() {
        let sequences = vec![
            SequenceObject {
                identifier: "Seq0".into(),
                sequence: "ACGTCG".into(),
            },
            SequenceObject {
                identifier: "Seq1".into(),
                sequence: "ACG-GC".into(),
            },
            SequenceObject {
                identifier: "Seq2".into(),
                sequence: "ACGCGT".into(),
            },
        ];
        let conf = NJConfig {
            msa: sequences,
            n_bootstrap_samples: 0,
            substitution_model: SubstitutionModel::PDiff,
        };
        let t1 = nj(conf.clone(), None).unwrap();
        let t2 = nj(conf, None).unwrap();
        assert_eq!(t1, t2);
    }

    #[test]
    fn test_nj_wrapper_empty_msa() {
        let conf = NJConfig {
            msa: vec![],
            n_bootstrap_samples: 0,
            substitution_model: SubstitutionModel::PDiff,
        };
        let result = nj(conf, None);
        assert!(result.is_err());
    }

    #[test]
    fn test_nj_wrapper_unequal_lengths_errors() {
        let conf = NJConfig {
            msa: vec![
                SequenceObject {
                    identifier: "A".into(),
                    sequence: "ACGT".into(),
                },
                SequenceObject {
                    identifier: "B".into(),
                    sequence: "ACG".into(),
                },
            ],
            n_bootstrap_samples: 0,
            substitution_model: SubstitutionModel::PDiff,
        };
        assert!(nj(conf, None).is_err());
    }

    #[test]
    fn test_nj_wrapper_incorrect_model_for_alphabet() {
        let sequences = vec![
            SequenceObject {
                identifier: "Seq0".into(),
                sequence: "ACGTCG".into(),
            },
            SequenceObject {
                identifier: "Seq1".into(),
                sequence: "ACG-GC".into(),
            },
        ];
        let conf = NJConfig {
            msa: sequences,
            n_bootstrap_samples: 0,
            substitution_model: SubstitutionModel::Poisson,
        };
        let result = nj(conf, None);
        assert!(result.is_err());
    }

    #[test]
    fn test_nj_wrapper_incorrect_model_for_protein() {
        let sequences = vec![
            SequenceObject {
                identifier: "Seq0".into(),
                sequence: "ACDEFGH".into(),
            },
            SequenceObject {
                identifier: "Seq1".into(),
                sequence: "ACD-FGH".into(),
            },
        ];
        let conf = NJConfig {
            msa: sequences,
            n_bootstrap_samples: 0,
            substitution_model: SubstitutionModel::JukesCantor,
        };
        let result = nj(conf, None);
        assert!(result.is_err());
    }

    #[test]
    fn test_nj_bootstrap_reports_progress() {
        let names = ["S0", "S1", "S2", "S3"];
        let seqs = ["ACGTACGTAC", "ACGTACGTTT", "ACGAACCTTT", "TCGAACCTTA"];
        let conf = NJConfig {
            msa: names
                .iter()
                .zip(seqs.iter())
                .map(|(id, seq)| SequenceObject {
                    identifier: id.to_string(),
                    sequence: seq.to_string(),
                })
                .collect(),
            n_bootstrap_samples: 5,
            substitution_model: SubstitutionModel::PDiff,
        };
        let calls: Rc<RefCell<Vec<(usize, usize)>>> = Rc::new(RefCell::new(Vec::new()));
        let sink = Rc::clone(&calls);
        let cb: Box<dyn Fn(usize, usize)> =
            Box::new(move |done, total| sink.borrow_mut().push((done, total)));
        let newick = nj(conf, Some(cb)).expect("NJ with bootstrap failed");
        assert!(newick.ends_with(';'));
        for name in names {
            assert!(newick.contains(name));
        }
        let expected: Vec<(usize, usize)> = (1..=5).map(|i| (i, 5)).collect();
        assert_eq!(*calls.borrow(), expected);
    }

    fn dist_conf(pairs: &[(&str, &str)], model: SubstitutionModel) -> DistConfig {
        DistConfig {
            msa: pairs
                .iter()
                .map(|(id, seq)| SequenceObject {
                    identifier: id.to_string(),
                    sequence: seq.to_string(),
                })
                .collect(),
            substitution_model: model,
        }
    }

    #[test]
    fn test_distance_matrix_names_and_shape() {
        let conf = dist_conf(&[("A", "ACGT"), ("B", "ACGA")], SubstitutionModel::PDiff);
        let result = distance_matrix(conf).unwrap();
        assert_eq!(result.names, vec!["A", "B"]);
        assert_eq!(result.matrix.len(), 2);
        assert_eq!(result.matrix[0].len(), 2);
        assert_eq!(result.matrix[1].len(), 2);
    }

    #[test]
    fn test_distance_matrix_diagonal_zero() {
        let conf = dist_conf(
            &[("A", "ACGT"), ("B", "ACGA"), ("C", "AGGT")],
            SubstitutionModel::PDiff,
        );
        let result = distance_matrix(conf).unwrap();
        for i in 0..3 {
            assert_eq!(result.matrix[i][i], 0.0);
        }
    }

    #[test]
    fn test_distance_matrix_symmetric() {
        let conf = dist_conf(
            &[("A", "ACGT"), ("B", "ACGA"), ("C", "AGGT")],
            SubstitutionModel::PDiff,
        );
        let result = distance_matrix(conf).unwrap();
        for i in 0..3 {
            for j in 0..3 {
                assert_eq!(result.matrix[i][j], result.matrix[j][i]);
            }
        }
    }

    #[test]
    fn test_distance_matrix_pdiff_known_value() {
        let conf = dist_conf(&[("A", "ACGT"), ("B", "ACGA")], SubstitutionModel::PDiff);
        let result = distance_matrix(conf).unwrap();
        assert!((result.matrix[0][1] - 0.25).abs() < 1e-12);
        assert!((result.matrix[1][0] - 0.25).abs() < 1e-12);
    }

    #[test]
    fn test_distance_matrix_identical_sequences_zero() {
        let conf = dist_conf(&[("A", "ACGT"), ("B", "ACGT")], SubstitutionModel::PDiff);
        let result = distance_matrix(conf).unwrap();
        assert_eq!(result.matrix[0][1], 0.0);
    }

    #[test]
    fn test_distance_matrix_jukes_cantor_dna() {
        let conf = dist_conf(
            &[("A", "ACGT"), ("B", "ACGA"), ("C", "AGGT")],
            SubstitutionModel::JukesCantor,
        );
        let result = distance_matrix(conf).unwrap();
        let expected = -0.75_f64 * (1.0_f64 - (4.0_f64 / 3.0) * 0.25).ln();
        assert!((result.matrix[0][1] - expected).abs() < 1e-10);
    }

    #[test]
    fn test_distance_matrix_kimura2p_dna() {
        let conf = dist_conf(
            &[("A", "ACGT"), ("B", "ACGA"), ("C", "AGGT")],
            SubstitutionModel::Kimura2P,
        );
        let result = distance_matrix(conf).unwrap();
        assert_eq!(result.names, vec!["A", "B", "C"]);
        assert!(result.matrix[0][0] == 0.0);
    }

    #[test]
    fn test_distance_matrix_poisson_protein() {
        let conf = dist_conf(
            &[("A", "ACDEFGH"), ("B", "ACDEFGK")],
            SubstitutionModel::Poisson,
        );
        let result = distance_matrix(conf).unwrap();
        let expected = -(1.0_f64 - 1.0 / 7.0).ln();
        assert!((result.matrix[0][1] - expected).abs() < 1e-10);
    }

    #[test]
    fn test_distance_matrix_pdiff_protein() {
        let conf = dist_conf(&[("A", "ACDEFGH"), ("B", "ACDEFGK")], SubstitutionModel::PDiff);
        let result = distance_matrix(conf).unwrap();
        assert!((result.matrix[0][1] - 1.0 / 7.0).abs() < 1e-12);
    }

    #[test]
    fn test_distance_matrix_empty_msa_errors() {
        let conf = DistConfig {
            msa: vec![],
            substitution_model: SubstitutionModel::PDiff,
        };
        assert!(distance_matrix(conf).is_err());
    }

    #[test]
    fn test_distance_matrix_empty_sequences_errors() {
        let conf = dist_conf(&[("A", ""), ("B", "")], SubstitutionModel::PDiff);
        assert!(distance_matrix(conf).is_err());
    }

    #[test]
    fn test_distance_matrix_incompatible_model_errors() {
        let conf = dist_conf(&[("A", "ACGT"), ("B", "ACGA")], SubstitutionModel::Poisson);
        assert!(distance_matrix(conf).is_err());
        let conf = dist_conf(
            &[("A", "ACDEFGH"), ("B", "ACDEFGK")],
            SubstitutionModel::JukesCantor,
        );
        assert!(distance_matrix(conf).is_err());
    }

    #[test]
    fn test_average_distance_identical_sequences_zero() {
        let conf = dist_conf(&[("A", "ACGT"), ("B", "ACGT")], SubstitutionModel::PDiff);
        let avg = average_distance(conf).unwrap();
        assert_eq!(avg, 0.0);
    }

    #[test]
    fn test_average_distance_two_taxa_equals_pairwise() {
        let conf = dist_conf(&[("A", "ACGT"), ("B", "ACGA")], SubstitutionModel::PDiff);
        let avg = average_distance(conf).unwrap();
        assert!((avg - 0.25).abs() < 1e-12);
    }

    #[test]
    fn test_average_distance_three_taxa_known_value() {
        let conf = dist_conf(
            &[("A", "ACGT"), ("B", "ACGA"), ("C", "AGGT")],
            SubstitutionModel::PDiff,
        );
        let avg = average_distance(conf).unwrap();
        assert!((avg - 1.0 / 3.0).abs() < 1e-12);
    }

    #[test]
    fn test_average_distance_jukes_cantor_dna() {
        let conf = dist_conf(&[("A", "ACGT"), ("B", "ACGA")], SubstitutionModel::JukesCantor);
        let avg = average_distance(conf).unwrap();
        let expected = -0.75_f64 * (1.0_f64 - (4.0_f64 / 3.0) * 0.25).ln();
        assert!((avg - expected).abs() < 1e-10);
    }

    #[test]
    fn test_average_distance_empty_msa_errors() {
        let conf = DistConfig {
            msa: vec![],
            substitution_model: SubstitutionModel::PDiff,
        };
        assert!(average_distance(conf).is_err());
    }

    #[test]
    fn test_average_distance_incompatible_model_errors() {
        let conf = dist_conf(&[("A", "ACGT"), ("B", "ACGA")], SubstitutionModel::Poisson);
        assert!(average_distance(conf).is_err());
    }

    #[test]
    fn test_detect_alphabet_dna() {
        let msa = vec![
            SequenceObject {
                identifier: "Seq0".into(),
                sequence: "ACGTACGT".into(),
            },
            SequenceObject {
                identifier: "Seq1".into(),
                sequence: "ACG-ACGT".into(),
            },
        ];
        let alphabet = detect_alphabet(&msa).expect("detection failed");
        assert_eq!(alphabet, Alphabet::DNA);
    }

    #[test]
    fn test_detect_alphabet_lowercase_dna() {
        let msa = vec![SequenceObject {
            identifier: "Seq0".into(),
            sequence: "acgt-n".into(),
        }];
        let alphabet = detect_alphabet(&msa).expect("detection failed");
        assert_eq!(alphabet, Alphabet::DNA);
    }

    #[test]
    fn test_detect_alphabet_protein() {
        let msa = vec![
            SequenceObject {
                identifier: "Seq0".into(),
                sequence: "ACDEFGHIK".into(),
            },
            SequenceObject {
                identifier: "Seq1".into(),
                sequence: "ACD-FGHIK".into(),
            },
        ];
        let alphabet = detect_alphabet(&msa).expect("detection failed");
        assert_eq!(alphabet, Alphabet::Protein);
    }
}