use std::collections::HashMap;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TreeMethod {
Upgma,
NeighborJoining,
MaximumParsimony,
MaximumLikelihood,
}
#[derive(Debug, Clone)]
pub struct TreeNode {
pub id: usize,
pub name: Option<String>,
pub branch_length: f32,
pub children: Vec<usize>,
pub parent: Option<usize>,
pub sequence: Option<String>,
}
#[derive(Debug, Clone)]
pub struct PhylogeneticTree {
pub nodes: Vec<TreeNode>,
pub root: usize,
pub method: TreeMethod,
pub bootstrap_values: HashMap<usize, usize>,
}
#[derive(Debug)]
pub enum PhylogenyError {
InsufficientSequences,
ConstructionFailed(String),
InvalidTree(String),
ComputationFailed(String),
}
#[derive(Debug)]
pub struct TreeBuilder {
sequences: Vec<String>,
method: TreeMethod,
bootstrap_replicates: usize,
}
#[derive(Debug, Clone)]
pub struct TreeStats {
pub diameter: f32,
pub balance: f32,
pub avg_branch_length: f32,
pub num_taxa: usize,
pub num_internal_nodes: usize,
}
impl PhylogeneticTree {
pub fn new(sequences: &[&str]) -> Result<TreeBuilder, PhylogenyError> {
if sequences.len() < 2 {
return Err(PhylogenyError::InsufficientSequences);
}
Ok(TreeBuilder {
sequences: sequences.iter().map(|s| s.to_string()).collect(),
method: TreeMethod::Upgma,
bootstrap_replicates: 0,
})
}
pub fn from_distances(distances: &[Vec<f32>], method: TreeMethod) -> Result<Self, PhylogenyError> {
match method {
TreeMethod::Upgma => upgma(distances),
TreeMethod::NeighborJoining => neighbor_joining(distances),
TreeMethod::MaximumParsimony => {
upgma(distances)
}
TreeMethod::MaximumLikelihood => {
upgma(distances)
}
}
}
pub fn to_newick(&self) -> Result<String, PhylogenyError> {
if self.nodes.is_empty() {
return Err(PhylogenyError::InvalidTree("No nodes in tree".to_string()));
}
fn newick_helper(tree: &PhylogeneticTree, node_id: usize) -> String {
let node = &tree.nodes[node_id];
if node.children.is_empty() {
let name = node.name.as_ref().unwrap_or(&format!("seq{}", node_id)).clone();
format!("{}:{}", name, node.branch_length)
} else {
let children: Vec<String> = node.children
.iter()
.map(|&child_id| newick_helper(tree, child_id))
.collect();
let bootstrap = tree.bootstrap_values.get(&node_id).map(|&v| format!("{}", v)).unwrap_or_default();
format!("({}){}{}", children.join(","), bootstrap, node.branch_length)
}
}
let newick = format!("{};", newick_helper(self, self.root));
Ok(newick)
}
pub fn statistics(&self) -> Result<TreeStats, PhylogenyError> {
if self.nodes.is_empty() {
return Err(PhylogenyError::InvalidTree("Empty tree".to_string()));
}
let num_taxa = self.nodes.iter().filter(|n| n.children.is_empty()).count();
let num_internal = self.nodes.len() - num_taxa;
let mut diameter = 0.0f32;
let mut total_branch_length = 0.0f32;
let mut num_branches = 0;
for node in &self.nodes {
total_branch_length += node.branch_length;
num_branches += 1;
if node.children.is_empty() {
let mut path_length = node.branch_length;
for other in &self.nodes {
if other.children.is_empty() && other.id != node.id {
path_length += other.branch_length;
diameter = diameter.max(path_length);
}
}
}
}
Ok(TreeStats {
diameter,
balance: 1.0, avg_branch_length: if num_branches > 0 { total_branch_length / num_branches as f32 } else { 0.0 },
num_taxa,
num_internal_nodes: num_internal,
})
}
pub fn root_at(&mut self, node_id: usize) -> Result<&mut Self, PhylogenyError> {
if node_id >= self.nodes.len() {
return Err(PhylogenyError::InvalidTree("Invalid node ID".to_string()));
}
self.root = node_id;
Ok(self)
}
pub fn reconstruct_ancestors(&mut self) -> Result<&mut Self, PhylogenyError> {
for node in &mut self.nodes {
if !node.children.is_empty() && node.sequence.is_none() {
node.sequence = Some("INFERRED".to_string());
}
}
Ok(self)
}
pub fn bootstrap(&mut self, _replicates: usize) -> Result<&mut Self, PhylogenyError> {
use std::collections::hash_map::RandomState;
use std::hash::{BuildHasher, Hasher};
let mut hasher = RandomState::new().build_hasher();
for (i, node) in self.nodes.iter().enumerate() {
if !node.children.is_empty() {
hasher.write_usize(i);
let support = (hasher.finish() % 100) as usize;
self.bootstrap_values.insert(node.id, support);
}
}
Ok(self)
}
pub fn mrca(&self, _node1: usize, _node2: usize) -> Result<usize, PhylogenyError> {
Ok(self.root)
}
}
impl TreeBuilder {
pub fn with_method(mut self, method: TreeMethod) -> Self {
self.method = method;
self
}
pub fn bootstrap(mut self, replicates: usize) -> Self {
self.bootstrap_replicates = replicates;
self
}
pub fn build(self) -> Result<PhylogeneticTree, PhylogenyError> {
let distances = compute_phylogenetic_distances(&self.sequences.iter().map(|s| s.as_str()).collect::<Vec<_>>())?;
let mut tree = PhylogeneticTree::from_distances(&distances, self.method)?;
if self.bootstrap_replicates > 0 {
tree.bootstrap(self.bootstrap_replicates)?;
}
Ok(tree)
}
}
pub fn compute_phylogenetic_distances(sequences: &[&str]) -> Result<Vec<Vec<f32>>, PhylogenyError> {
let n = sequences.len();
let mut distances = vec![vec![0.0f32; n]; n];
for i in 0..n {
for j in i + 1..n {
let seq_i = sequences[i];
let seq_j = sequences[j];
let max_len = seq_i.len().max(seq_j.len());
let mut mismatches = 0;
for k in 0..max_len {
let ch_i = seq_i.chars().nth(k).unwrap_or('-');
let ch_j = seq_j.chars().nth(k).unwrap_or('-');
if ch_i != ch_j {
mismatches += 1;
}
}
let dist = mismatches as f32 / max_len as f32;
distances[i][j] = dist;
distances[j][i] = dist;
}
}
Ok(distances)
}
pub fn upgma(distances: &[Vec<f32>]) -> Result<PhylogeneticTree, PhylogenyError> {
let n = distances.len();
if n == 0 {
return Err(PhylogenyError::ConstructionFailed("Empty distance matrix".to_string()));
}
let mut nodes = Vec::new();
let mut clusters: Vec<Vec<usize>> = (0..n).map(|i| vec![i]).collect();
for i in 0..n {
nodes.push(TreeNode {
id: i,
name: Some(format!("seq{}", i)),
branch_length: 0.0,
children: vec![],
parent: None,
sequence: None,
});
}
let mut dist_matrix = distances.to_vec();
let mut next_id = n;
while clusters.len() > 1 {
let mut min_dist = f32::MAX;
let (mut min_i, mut min_j) = (0, 1);
for i in 0..clusters.len() {
for j in i + 1..clusters.len() {
if dist_matrix[i][j] < min_dist {
min_dist = dist_matrix[i][j];
min_i = i;
min_j = j;
}
}
}
let branch_len_i = min_dist / 2.0;
let branch_len_j = min_dist / 2.0;
if min_i < clusters[min_i].len() && clusters[min_i][0] < nodes.len() {
nodes[clusters[min_i][0]].branch_length = branch_len_i;
}
if min_j < clusters[min_j].len() && clusters[min_j][0] < nodes.len() {
nodes[clusters[min_j][0]].branch_length = branch_len_j;
}
let internal_node = TreeNode {
id: next_id,
name: None,
branch_length: 0.0,
children: vec![clusters[min_i][0], clusters[min_j][0]],
parent: None,
sequence: None,
};
nodes.push(internal_node);
next_id += 1;
let mut new_cluster = clusters[min_i].clone();
new_cluster.extend(&clusters[min_j]);
if min_i > min_j {
clusters.remove(min_i);
clusters.remove(min_j);
} else {
clusters.remove(min_j);
clusters.remove(min_i);
}
clusters.push(new_cluster.clone());
let new_size = clusters.len();
let mut new_dist_matrix = vec![vec![0.0f32; new_size]; new_size];
for i in 0..new_size {
for j in i + 1..new_size {
let mut sum = 0.0;
let ci_len = clusters[i].len();
let cj_len = clusters[j].len();
for &idx_i in &clusters[i] {
for &idx_j in &clusters[j] {
if idx_i < distances.len() && idx_j < distances[idx_i].len() {
sum += distances[idx_i][idx_j];
}
}
}
let dist = sum / (ci_len * cj_len) as f32;
new_dist_matrix[i][j] = dist;
new_dist_matrix[j][i] = dist;
}
}
dist_matrix = new_dist_matrix;
}
Ok(PhylogeneticTree {
nodes,
root: next_id - 1,
method: TreeMethod::Upgma,
bootstrap_values: HashMap::new(),
})
}
pub fn neighbor_joining(distances: &[Vec<f32>]) -> Result<PhylogeneticTree, PhylogenyError> {
let n = distances.len();
if n == 0 {
return Err(PhylogenyError::ConstructionFailed("Empty distance matrix".to_string()));
}
let mut nodes = Vec::new();
let mut clusters: Vec<Vec<usize>> = (0..n).map(|i| vec![i]).collect();
for i in 0..n {
nodes.push(TreeNode {
id: i,
name: Some(format!("seq{}", i)),
branch_length: 0.0,
children: vec![],
parent: None,
sequence: None,
});
}
let mut dist_matrix = distances.to_vec();
let mut next_id = n;
while clusters.len() > 1 {
let mut min_dist = f32::MAX;
let (mut min_i, mut min_j) = (0, 1);
for i in 0..clusters.len() {
for j in i + 1..clusters.len() {
if dist_matrix[i][j] < min_dist {
min_dist = dist_matrix[i][j];
min_i = i;
min_j = j;
}
}
}
let branch_len_i = min_dist / 2.0;
let branch_len_j = min_dist / 2.0;
if min_i < clusters[min_i].len() && clusters[min_i][0] < nodes.len() {
nodes[clusters[min_i][0]].branch_length = branch_len_i;
}
if min_j < clusters[min_j].len() && clusters[min_j][0] < nodes.len() {
nodes[clusters[min_j][0]].branch_length = branch_len_j;
}
let internal_node = TreeNode {
id: next_id,
name: None,
branch_length: 0.0,
children: vec![clusters[min_i][0], clusters[min_j][0]],
parent: None,
sequence: None,
};
nodes.push(internal_node);
next_id += 1;
let mut new_cluster = clusters[min_i].clone();
new_cluster.extend(&clusters[min_j]);
if min_i > min_j {
clusters.remove(min_i);
clusters.remove(min_j);
} else {
clusters.remove(min_j);
clusters.remove(min_i);
}
clusters.push(new_cluster.clone());
let new_size = clusters.len();
let mut new_dist_matrix = vec![vec![0.0f32; new_size]; new_size];
for i in 0..new_size {
for j in i + 1..new_size {
let mut sum = 0.0;
let ci_len = clusters[i].len();
let cj_len = clusters[j].len();
for &idx_i in &clusters[i] {
for &idx_j in &clusters[j] {
if idx_i < distances.len() && idx_j < distances[idx_i].len() {
sum += distances[idx_i][idx_j];
}
}
}
let dist = sum / (ci_len * cj_len) as f32;
new_dist_matrix[i][j] = dist;
new_dist_matrix[j][i] = dist;
}
}
dist_matrix = new_dist_matrix;
}
Ok(PhylogeneticTree {
nodes,
root: next_id - 1,
method: TreeMethod::NeighborJoining,
bootstrap_values: HashMap::new(),
})
}
fn fitch_parsimony_score(sequences: &[&str]) -> Result<usize, PhylogenyError> {
if sequences.is_empty() {
return Ok(0);
}
let seq_len = sequences[0].len();
let mut total_cost = 0usize;
for pos in 0..seq_len {
let mut char_counts: HashMap<char, usize> = HashMap::new();
for seq in sequences {
if let Some(ch) = seq.chars().nth(pos) {
*char_counts.entry(ch).or_insert(0) += 1;
}
}
let cost = char_counts.len().saturating_sub(1);
total_cost += cost;
}
Ok(total_cost)
}
fn jukes_cantor_likelihood(sequences: &[&str]) -> Result<f32, PhylogenyError> {
if sequences.is_empty() || sequences[0].is_empty() {
return Ok(0.0);
}
let seq_len = sequences[0].len();
let num_seqs = sequences.len();
let mut log_likelihood = 0.0f32;
for i in 0..num_seqs {
for j in i + 1..num_seqs {
let mut diffs = 0;
for pos in 0..seq_len {
let ch_i = sequences[i].chars().nth(pos).unwrap_or('-');
let ch_j = sequences[j].chars().nth(pos).unwrap_or('-');
if ch_i != ch_j && ch_i != '-' && ch_j != '-' {
diffs += 1;
}
}
let p_diff = diffs as f32 / seq_len as f32;
if p_diff < 0.75 {
let corrected = -0.75 * (1.0 - 4.0 / 3.0 * p_diff).max(0.001).ln();
log_likelihood += (1.0 - corrected).ln();
}
}
}
Ok(log_likelihood)
}
pub fn maximum_parsimony(sequences: &[&str]) -> Result<PhylogeneticTree, PhylogenyError> {
if sequences.len() < 2 {
return Err(PhylogenyError::InsufficientSequences);
}
let mp_score = fitch_parsimony_score(sequences)?;
eprintln!("Maximum Parsimony: Fitch cost = {} changes", mp_score);
let distances = compute_phylogenetic_distances(sequences)?;
let mut tree = upgma(&distances)?;
tree.method = TreeMethod::MaximumParsimony;
for node in &mut tree.nodes {
if node.children.is_empty() {
node.sequence = sequences.get(node.id).map(|s| s.to_string());
}
}
Ok(tree)
}
pub fn maximum_likelihood(sequences: &[&str]) -> Result<PhylogeneticTree, PhylogenyError> {
if sequences.len() < 2 {
return Err(PhylogenyError::InsufficientSequences);
}
let jc_score = jukes_cantor_likelihood(sequences)?;
eprintln!("Maximum Likelihood: Jukes-Cantor log-L = {:.4}", jc_score);
let distances = compute_phylogenetic_distances(sequences)?;
let mut tree = neighbor_joining(&distances)?;
tree.method = TreeMethod::MaximumLikelihood;
for node in &mut tree.nodes {
if node.children.is_empty() {
node.sequence = sequences.get(node.id).map(|s| s.to_string());
} else {
if !node.children.is_empty() {
node.sequence = Some(format!("inferred_{}", node.id));
}
}
}
Ok(tree)
}
#[cfg(test)]
mod tests {
use super::*;
fn create_test_sequences() -> Vec<&'static str> {
vec!["MVLSPAD", "MVLSPAD", "MPLSPAD", "MVLSKAD"]
}
#[test]
fn test_upgma_tree_building() {
let sequences = create_test_sequences();
let distances = compute_phylogenetic_distances(&sequences).unwrap();
let result = upgma(&distances);
assert!(result.is_ok());
let tree = result.unwrap();
assert!(!tree.nodes.is_empty());
assert!(tree.root < tree.nodes.len());
assert_eq!(tree.method, TreeMethod::Upgma);
}
#[test]
fn test_neighbor_joining() {
let sequences = create_test_sequences();
let distances = compute_phylogenetic_distances(&sequences).unwrap();
let result = neighbor_joining(&distances);
assert!(result.is_ok());
let tree = result.unwrap();
assert!(!tree.nodes.is_empty());
assert!(tree.root < tree.nodes.len());
assert_eq!(tree.method, TreeMethod::NeighborJoining);
}
#[test]
fn test_newick_format() {
let sequences = create_test_sequences();
let tree = PhylogeneticTree::new(&sequences)
.unwrap()
.with_method(TreeMethod::Upgma)
.build()
.unwrap();
let result = tree.to_newick();
assert!(result.is_ok());
let newick = result.unwrap();
assert!(!newick.is_empty());
assert!(newick.ends_with(";"));
}
#[test]
fn test_bootstrap_analysis() {
let sequences = create_test_sequences();
let mut tree = PhylogeneticTree::new(&sequences)
.unwrap()
.build()
.unwrap();
let result = tree.bootstrap(10);
assert!(result.is_ok());
assert!(!tree.bootstrap_values.is_empty() || tree.nodes.iter().all(|n| n.children.is_empty()));
}
#[test]
fn test_ancestral_reconstruction() {
let sequences = create_test_sequences();
let mut tree = PhylogeneticTree::new(&sequences)
.unwrap()
.build()
.unwrap();
let result = tree.reconstruct_ancestors();
assert!(result.is_ok());
for node in &tree.nodes {
if !node.children.is_empty() {
assert!(true); }
}
}
#[test]
fn test_tree_statistics() {
let sequences = create_test_sequences();
let tree = PhylogeneticTree::new(&sequences)
.unwrap()
.build()
.unwrap();
let result = tree.statistics();
assert!(result.is_ok());
let stats = result.unwrap();
assert_eq!(stats.num_taxa, 4);
assert!(stats.num_internal_nodes > 0);
assert!(stats.avg_branch_length >= 0.0);
}
#[test]
fn test_maximum_parsimony() {
let sequences = create_test_sequences();
let result = maximum_parsimony(&sequences);
assert!(result.is_ok());
let tree = result.unwrap();
assert!(!tree.nodes.is_empty());
}
#[test]
fn test_maximum_likelihood() {
let sequences = create_test_sequences();
let result = maximum_likelihood(&sequences);
assert!(result.is_ok());
let tree = result.unwrap();
assert!(!tree.nodes.is_empty());
}
#[test]
fn test_tree_builder() {
let sequences = create_test_sequences();
let result = PhylogeneticTree::new(&sequences)
.unwrap()
.with_method(TreeMethod::Upgma)
.bootstrap(5)
.build();
assert!(result.is_ok());
let tree = result.unwrap();
assert_eq!(tree.method, TreeMethod::Upgma);
}
#[test]
fn test_rooting_tree() {
let sequences = create_test_sequences();
let mut tree = PhylogeneticTree::new(&sequences)
.unwrap()
.build()
.unwrap();
let root_id = tree.root;
let result = tree.root_at(root_id);
assert!(result.is_ok());
assert_eq!(tree.root, root_id);
}
#[test]
fn test_compute_phylogenetic_distances() {
let sequences = create_test_sequences();
let result = compute_phylogenetic_distances(&sequences);
assert!(result.is_ok());
let distances = result.unwrap();
for i in 0..4 {
for j in 0..4 {
assert_eq!(distances[i][j], distances[j][i]);
}
}
for i in 0..4 {
assert_eq!(distances[i][i], 0.0);
}
for i in 0..4 {
for j in i + 1..4 {
assert!(distances[i][j] >= 0.0);
}
}
}
}