1pub mod alphabet;
49pub mod config;
50pub mod distance_matrix;
51pub mod models;
52pub mod msa;
53pub mod nj;
54pub mod tree;
55
56use bitvec::prelude::{BitVec, Lsb0, bitvec};
57use std::collections::HashMap;
58
59use crate::alphabet::{Alphabet, AlphabetEncoding, DNA, Protein};
60use crate::config::SubstitutionModel;
61pub use crate::config::{MSA, NJConfig, SequenceObject};
62use crate::distance_matrix::DistMat;
63use crate::models::{JukesCantor, Kimura2P, ModelCalculation, PDiff, Poisson};
64use crate::tree::{NameOrSupport, TreeNode};
65
66fn bitset_of(
72 node: &TreeNode,
73 idx: &HashMap<String, usize>,
74 out: &mut BitVec<u8, Lsb0>,
75) -> Result<(), String> {
76 match &node.children {
77 None => match &node.label {
78 Some(NameOrSupport::Name(name)) => {
79 let i = idx[name];
80 out.set(i, true);
81 Ok(())
82 }
83 _ => Err("Leaf node without a name label".into()),
84 },
85 Some([l, r]) => {
86 bitset_of(l, idx, out)?;
87 bitset_of(r, idx, out)?;
88 Ok(())
89 }
90 }
91}
92
93fn count_clades(
100 tree: &TreeNode,
101 idx: &HashMap<String, usize>,
102 n_taxa: usize,
103 counter: &mut HashMap<Vec<u8>, usize>,
104) -> Result<(), String> {
105 if let Some([l, r]) = &tree.children {
106 let mut bv = bitvec![u8, Lsb0; 0; n_taxa];
108 bitset_of(tree, idx, &mut bv).unwrap();
109
110 let n = bv.count_ones();
111 if n > 1 && n < n_taxa {
112 counter
114 .entry(bv.as_raw_slice().to_vec())
115 .and_modify(|c| *c += 1)
116 .or_insert(1);
117 }
118
119 count_clades(l, idx, n_taxa, counter)?;
121 count_clades(r, idx, n_taxa, counter)?;
122 }
123 Ok(())
124}
125
126fn bootstrap_clade_counts<A: AlphabetEncoding, M: ModelCalculation<A>>(
128 msa: &MSA<A>,
129 n_bootstrap_samples: usize,
130) -> Result<Option<HashMap<Vec<u8>, usize>>, String> {
131 if n_bootstrap_samples == 0 {
132 return Ok(None);
133 }
134 let idx_map: HashMap<String, usize> = msa.to_index_map();
135 let mut counter = HashMap::new();
136 for _ in 0..n_bootstrap_samples {
137 let tree = msa
138 .bootstrap()?
139 .into_dist::<M>()
140 .neighbor_joining()
141 .expect("NJ bootstrap iteration failed");
142 count_clades(&tree, &idx_map, msa.n_sequences, &mut counter)?;
143 }
144 Ok(Some(counter))
145}
146
147fn add_bootstrap_to_tree(
154 node: &mut TreeNode,
155 idx: &HashMap<String, usize>,
156 n_taxa: usize,
157 counts: &HashMap<Vec<u8>, usize>,
158) {
159 if node.children.is_some() {
160 let mut bv = bitvec![u8, Lsb0; 0; n_taxa];
162 bitset_of(node, idx, &mut bv).unwrap();
163
164 let n = bv.count_ones();
165 if n > 1 && n < n_taxa {
166 if let Some(c) = counts.get(&bv.as_raw_slice().to_vec()) {
167 node.label = Some(NameOrSupport::Support(*c));
168 }
169 }
170
171 if let Some([l, r]) = &mut node.children {
173 add_bootstrap_to_tree(l, idx, n_taxa, counts);
174 add_bootstrap_to_tree(r, idx, n_taxa, counts);
175 }
176 }
177}
178
179fn detect_alphabet(msa: &[SequenceObject]) -> Result<Alphabet, String> {
187 let mut is_protein = false;
189
190 for seq in msa {
191 for c in seq.sequence.bytes() {
192 match c.to_ascii_uppercase() {
193 b'A' | b'C' | b'G' | b'T' | b'U' | b'N' | b'-' => { }
194 _ => {
195 is_protein = true;
196 break;
197 }
198 }
199 }
200 if is_protein {
201 break;
202 }
203 }
204
205 Ok(if is_protein {
206 Alphabet::Protein
207 } else {
208 Alphabet::DNA
209 })
210}
211
212fn run_nj<A, M>(msa: MSA<A>, n_bootstrap_samples: usize) -> Result<String, String>
218where
219 A: AlphabetEncoding,
220 M: ModelCalculation<A>,
221{
222 let clade_counts = bootstrap_clade_counts::<A, M>(&msa, n_bootstrap_samples)?;
224
225 let mut main_tree = msa.into_dist::<M>().neighbor_joining()?;
226 let newick = match clade_counts {
227 Some(counts) => {
228 let main_idx_map: HashMap<String, usize> = msa.to_index_map();
229 add_bootstrap_to_tree(&mut main_tree, &main_idx_map, msa.n_sequences, &counts);
230 main_tree.to_newick()
231 }
232 None => main_tree.to_newick(),
233 };
234 Ok(newick)
235}
236
237pub fn nj(conf: NJConfig) -> Result<String, String> {
245 if conf.msa.is_empty() {
246 return Err("Input MSA is empty".into());
247 }
248 let alphabet = detect_alphabet(&conf.msa)?;
249 match alphabet {
250 Alphabet::DNA => {
251 let msa =
253 MSA::<DNA>::from_iter(conf.msa.into_iter().map(|s| (s.identifier, s.sequence)));
254
255 match conf.substitution_model {
256 SubstitutionModel::PDiff => run_nj::<DNA, PDiff>(msa, conf.n_bootstrap_samples),
257 SubstitutionModel::JukesCantor => {
258 run_nj::<DNA, JukesCantor>(msa, conf.n_bootstrap_samples)
259 }
260 SubstitutionModel::Kimura2P => {
261 run_nj::<DNA, Kimura2P>(msa, conf.n_bootstrap_samples)
262 }
263 SubstitutionModel::Poisson => {
265 Err("Poisson is a protein model; cannot use with DNA".into())
266 }
267 }
268 }
269
270 Alphabet::Protein => {
271 let msa =
272 MSA::<Protein>::from_iter(conf.msa.into_iter().map(|s| (s.identifier, s.sequence)));
273
274 match conf.substitution_model {
275 SubstitutionModel::Poisson => {
277 run_nj::<Protein, Poisson>(msa, conf.n_bootstrap_samples)
278 }
279 SubstitutionModel::PDiff => run_nj::<Protein, PDiff>(msa, conf.n_bootstrap_samples),
280 SubstitutionModel::JukesCantor | SubstitutionModel::Kimura2P => {
282 Err("Selected model is for DNA; cannot use with Protein".into())
283 }
284 }
285 }
286 }
287}
288
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::SubstitutionModel;

    /// Builds `SequenceObject`s from `(identifier, sequence)` pairs.
    fn seqs(pairs: &[(&str, &str)]) -> Vec<SequenceObject> {
        pairs
            .iter()
            .map(|&(identifier, sequence)| SequenceObject {
                identifier: identifier.into(),
                sequence: sequence.into(),
            })
            .collect()
    }

    /// Builds an `NJConfig` from an alignment, a bootstrap count and a model.
    fn config(
        msa: Vec<SequenceObject>,
        n_bootstrap_samples: usize,
        substitution_model: SubstitutionModel,
    ) -> NJConfig {
        NJConfig {
            msa,
            n_bootstrap_samples,
            substitution_model,
        }
    }

    #[test]
    fn test_nj_wrapper_simple_tree() {
        let conf = config(
            seqs(&[("A", "ACGTCG"), ("B", "ACG-GC")]),
            0,
            SubstitutionModel::PDiff,
        );
        let newick = nj(conf).expect("NJ failed");
        assert_eq!(newick, "(A:0.167,B:0.167);");
    }

    #[test]
    fn test_nj_wrapper_adds_semicolon() {
        let conf = config(
            seqs(&[("Seq0", "A"), ("Seq1", "A")]),
            0,
            SubstitutionModel::PDiff,
        );
        let out = nj(conf).unwrap();
        assert!(out.ends_with(';'));
    }

    #[test]
    fn test_nj_deterministic_order() {
        let conf = config(
            seqs(&[
                ("Seq0", "ACGTCG"),
                ("Seq1", "ACG-GC"),
                ("Seq2", "ACGCGT"),
            ]),
            0,
            SubstitutionModel::PDiff,
        );

        let t1 = nj(conf.clone()).unwrap();
        let t2 = nj(conf).unwrap();
        assert_eq!(t1, t2);
    }

    #[test]
    fn test_nj_wrapper_empty_msa() {
        let conf = config(vec![], 0, SubstitutionModel::PDiff);
        assert_eq!(nj(conf), Err("Input MSA is empty".to_string()));
    }

    #[test]
    fn test_nj_wrapper_incorrect_model_for_alphabet() {
        let conf = config(
            seqs(&[("Seq0", "ACGTCG"), ("Seq1", "ACG-GC")]),
            0,
            SubstitutionModel::Poisson,
        );
        let result = nj(conf);
        assert!(result.is_err());
    }

    #[test]
    fn test_nj_wrapper_incorrect_model_for_protein() {
        let conf = config(
            seqs(&[("Seq0", "ACDEFGH"), ("Seq1", "ACD-FGH")]),
            0,
            SubstitutionModel::JukesCantor,
        );
        let result = nj(conf);
        assert!(result.is_err());
    }

    #[test]
    fn test_nj_wrapper_with_bootstrap() {
        // Exercises the bootstrap path (clade counting and support annotation), which every
        // other test skips by using `n_bootstrap_samples: 0`.
        let conf = config(
            seqs(&[
                ("Seq0", "ACGTACGTAC"),
                ("Seq1", "ACGTACGTTC"),
                ("Seq2", "ACGAACCTTC"),
                ("Seq3", "TCGAACCTTG"),
            ]),
            10,
            SubstitutionModel::PDiff,
        );
        let newick = nj(conf).expect("NJ with bootstrap failed");
        assert!(newick.ends_with(';'));
        for name in ["Seq0", "Seq1", "Seq2", "Seq3"].iter() {
            assert!(newick.contains(*name), "missing {} in {}", name, newick);
        }
    }

    #[test]
    fn test_detect_alphabet_dna() {
        let msa = seqs(&[("Seq0", "ACGTACGT"), ("Seq1", "ACG-ACGT")]);
        let alphabet = detect_alphabet(&msa).expect("detection failed");
        assert_eq!(alphabet, Alphabet::DNA);
    }

    #[test]
    fn test_detect_alphabet_protein() {
        let msa = seqs(&[("Seq0", "ACDEFGHIK"), ("Seq1", "ACD-FGHIK")]);
        let alphabet = detect_alphabet(&msa).expect("detection failed");
        assert_eq!(alphabet, Alphabet::Protein);
    }
}