1pub mod alphabet;
49pub mod config;
50pub mod distance_matrix;
51pub mod error;
52pub mod event;
53pub mod fasta;
54pub mod models;
55pub mod msa;
56pub mod nj;
57pub mod tree;
58
59use bitvec::prelude::{BitVec, Lsb0, bitvec};
60use std::collections::HashMap;
61
62use crate::alphabet::{Alphabet, AlphabetEncoding, DNA, Protein};
63use crate::config::SubstitutionModel;
64pub use crate::config::{DistConfig, MSA, NJConfig, NJResult, SequenceObject};
65use crate::distance_matrix::DistMat;
66pub use crate::distance_matrix::DistanceResult;
67pub use crate::error::NJError;
68pub use crate::event::{LogLevel, NJEvent};
69pub use crate::fasta::parse_fasta;
70use crate::models::{JukesCantor, Kimura2P, ModelCalculation, PDiff, Poisson};
71use crate::tree::{NameOrSupport, TreeNode};
72
73fn bitset_of(
79 node: &TreeNode,
80 idx: &HashMap<String, usize>,
81 out: &mut BitVec<u8, Lsb0>,
82) -> Result<(), String> {
83 match &node.children {
84 None => match &node.label {
85 Some(NameOrSupport::Name(name)) => {
86 let i = idx[name];
87 out.set(i, true);
88 Ok(())
89 }
90 _ => Err("Leaf node without a name label".into()),
91 },
92 Some([l, r]) => {
93 bitset_of(l, idx, out)?;
94 bitset_of(r, idx, out)?;
95 Ok(())
96 }
97 }
98}
99
100fn count_clades(
107 tree: &TreeNode,
108 idx: &HashMap<String, usize>,
109 n_taxa: usize,
110 counter: &mut HashMap<Vec<u8>, usize>,
111) -> Result<(), String> {
112 if let Some([l, r]) = &tree.children {
113 let mut bv = bitvec![u8, Lsb0; 0; n_taxa];
114 bitset_of(tree, idx, &mut bv)?;
115
116 let n = bv.count_ones();
117 if n > 1 && n < n_taxa {
118 counter
119 .entry(bv.as_raw_slice().to_vec())
120 .and_modify(|c| *c += 1)
121 .or_insert(1);
122 }
123
124 count_clades(l, idx, n_taxa, counter)?;
125 count_clades(r, idx, n_taxa, counter)?;
126 }
127 Ok(())
128}
129
/// Builds a rayon thread pool, optionally with a fixed number of threads.
///
/// When `num_threads` is `None` the builder is left unconfigured, so rayon
/// picks the thread count itself.
///
/// # Errors
///
/// Returns the builder's error, rendered as a string, if the pool cannot be
/// created.
#[cfg(feature = "parallel")]
pub(crate) fn build_thread_pool(num_threads: Option<usize>) -> Result<rayon::ThreadPool, String> {
    let builder = match num_threads {
        Some(n) => rayon::ThreadPoolBuilder::new().num_threads(n),
        None => rayon::ThreadPoolBuilder::new(),
    };
    builder.build().map_err(|err| err.to_string())
}
142
/// Runs `n_bootstrap_samples` bootstrap replicates on a rayon pool and merges
/// their clade counts.
///
/// Each replicate resamples `msa`, builds a distance matrix with model `M`,
/// runs neighbor joining, and counts the clades of the resulting tree. The
/// per-replicate counts are sent over a channel to the calling thread, which
/// adds them into one map and fires `NJEvent::BootstrapProgress` after every
/// received replicate. The callback is therefore only ever invoked from the
/// thread that called this function.
///
/// # Errors
///
/// Returns an error if the thread pool cannot be built, if any replicate fails
/// (resampling, neighbor joining, or clade counting), or if the channel closes
/// before all replicates have reported.
#[cfg(feature = "parallel")]
fn bootstrap_clade_counts_parallel<A, M>(
    msa: &MSA<A>,
    n_bootstrap_samples: usize,
    idx_map: &HashMap<String, usize>,
    n_taxa: usize,
    on_event: Option<&dyn Fn(NJEvent)>,
    num_threads: Option<usize>,
) -> Result<HashMap<Vec<u8>, usize>, String>
where
    A: AlphabetEncoding + Send + Sync,
    A::Symbol: Send + Sync,
    M: ModelCalculation<A> + Send + Sync,
{
    use rayon::iter::{IntoParallelIterator, ParallelIterator};
    use std::sync::mpsc;

    let pool = build_thread_pool(num_threads)?;
    let (tx, rx) = mpsc::channel::<Result<HashMap<Vec<u8>, usize>, String>>();
    let mut counter: HashMap<Vec<u8>, usize> = HashMap::new();

    std::thread::scope(|scope| -> Result<(), String> {
        // Producer: drives the rayon pool from a scoped thread so this thread
        // stays free to receive results and report progress.
        scope.spawn(|| {
            pool.install(|| {
                (0..n_bootstrap_samples)
                    .into_par_iter()
                    .for_each_with(tx, |sender, _| {
                        let result: Result<HashMap<Vec<u8>, usize>, String> = (|| {
                            // Propagate a failed replicate as an error value
                            // rather than panicking inside a worker thread.
                            let tree = msa.bootstrap()?.into_dist::<M>().neighbor_joining()?;
                            let mut local = HashMap::new();
                            count_clades(&tree, idx_map, n_taxa, &mut local)?;
                            Ok(local)
                        })();
                        // A send failure has no one left to report to; ignore it.
                        let _ = sender.send(result);
                    });
            });
        });

        // Consumer: exactly one message is expected per replicate.
        // Note that on an early error return the scope still waits for the
        // producer thread to finish before this function returns.
        for completed in 1..=n_bootstrap_samples {
            match rx.recv() {
                Ok(Ok(local)) => {
                    for (clade, count) in local {
                        *counter.entry(clade).or_insert(0) += count;
                    }
                }
                Ok(Err(e)) => return Err(e),
                Err(_) => return Err("bootstrap channel closed unexpectedly".into()),
            }
            if let Some(cb) = on_event {
                cb(NJEvent::BootstrapProgress {
                    completed,
                    total: n_bootstrap_samples,
                });
            }
        }
        Ok(())
    })?;

    Ok(counter)
}
222
223fn bootstrap_clade_counts<A, M>(
230 msa: &MSA<A>,
231 n_bootstrap_samples: usize,
232 on_event: Option<&dyn Fn(NJEvent)>,
233 num_threads: Option<usize>,
234) -> Result<Option<HashMap<Vec<u8>, usize>>, String>
235where
236 A: AlphabetEncoding + Send + Sync,
237 A::Symbol: Send + Sync,
238 M: ModelCalculation<A> + Send + Sync,
239{
240 if n_bootstrap_samples == 0 {
241 return Ok(None);
242 }
243 if let Some(cb) = on_event {
244 cb(NJEvent::BootstrapStarted {
245 total: n_bootstrap_samples,
246 });
247 }
248 let idx_map: HashMap<String, usize> = msa.to_index_map();
249 let n_taxa = msa.n_sequences;
250
251 #[cfg(feature = "parallel")]
252 let counter = bootstrap_clade_counts_parallel::<A, M>(
253 msa,
254 n_bootstrap_samples,
255 &idx_map,
256 n_taxa,
257 on_event,
258 num_threads,
259 )?;
260
261 #[cfg(not(feature = "parallel"))]
262 let counter = {
263 let _ = num_threads;
264 let mut c = HashMap::new();
265 for i in 0..n_bootstrap_samples {
266 let tree = msa
267 .bootstrap()?
268 .into_dist::<M>()
269 .neighbor_joining()
270 .expect("NJ bootstrap iteration failed");
271 count_clades(&tree, &idx_map, n_taxa, &mut c)?;
272 if let Some(cb) = on_event {
273 cb(NJEvent::BootstrapProgress {
274 completed: i + 1,
275 total: n_bootstrap_samples,
276 });
277 }
278 }
279 c
280 };
281
282 Ok(Some(counter))
283}
284
285fn add_bootstrap_to_tree(
292 node: &mut TreeNode,
293 idx: &HashMap<String, usize>,
294 n_taxa: usize,
295 counts: &HashMap<Vec<u8>, usize>,
296 n_bootstrap_samples: usize,
297) -> Result<(), String> {
298 if node.children.is_some() {
299 let mut bv = bitvec![u8, Lsb0; 0; n_taxa];
300 bitset_of(node, idx, &mut bv)?;
301
302 let n = bv.count_ones();
303 if n > 1 && n < n_taxa {
304 if let Some(c) = counts.get(&bv.as_raw_slice().to_vec()) {
305 let pct = c * 100 / n_bootstrap_samples;
306 node.label = Some(NameOrSupport::Support(pct));
307 }
308 }
309
310 if let Some([l, r]) = &mut node.children {
311 add_bootstrap_to_tree(l, idx, n_taxa, counts, n_bootstrap_samples)?;
312 add_bootstrap_to_tree(r, idx, n_taxa, counts, n_bootstrap_samples)?;
313 }
314 }
315 Ok(())
316}
317
318fn validate_msa(msa: &[SequenceObject]) -> Result<(), NJError> {
320 if msa.is_empty() {
321 return Err(NJError::EmptyMsa);
322 }
323 let expected_len = msa[0].sequence.len();
324 if expected_len == 0 {
325 return Err(NJError::EmptySequence);
326 }
327 for s in msa {
328 if s.sequence.len() != expected_len {
329 return Err(NJError::SequenceLengthMismatch {
330 expected: expected_len,
331 got: s.sequence.len(),
332 identifier: s.identifier.clone(),
333 });
334 }
335 }
336 Ok(())
337}
338
339fn detect_alphabet(msa: &[SequenceObject]) -> Alphabet {
347 let mut is_protein = false;
348
349 'outer: for seq in msa {
350 for c in seq.sequence.bytes() {
351 match c.to_ascii_uppercase() {
352 b'A' | b'C' | b'G' | b'T' | b'U' | b'N' | b'-' | b'R' | b'Y' | b'S' | b'W'
353 | b'K' | b'M' | b'B' | b'D' | b'H' | b'V' => { }
354 _ => {
355 is_protein = true;
356 break 'outer;
357 }
358 }
359 }
360 }
361
362 if is_protein {
363 Alphabet::Protein
364 } else {
365 Alphabet::DNA
366 }
367}
368
369fn run_distance_matrix<A, M>(
371 msa: MSA<A>,
372 num_threads: Option<usize>,
373) -> Result<DistanceResult, String>
374where
375 A: AlphabetEncoding + Send + Sync,
376 A::Symbol: Send + Sync,
377 M: ModelCalculation<A> + Send + Sync,
378{
379 #[cfg(feature = "parallel")]
380 {
381 let pool = build_thread_pool(num_threads)?;
382 Ok(pool.install(|| msa.into_dist::<M>()).into_result())
383 }
384 #[cfg(not(feature = "parallel"))]
385 {
386 let _ = num_threads;
387 Ok(msa.into_dist::<M>().into_result())
388 }
389}
390
391fn run_average_distance<A, M>(msa: MSA<A>, num_threads: Option<usize>) -> Result<f64, String>
393where
394 A: AlphabetEncoding + Send + Sync,
395 A::Symbol: Send + Sync,
396 M: ModelCalculation<A> + Send + Sync,
397{
398 #[cfg(feature = "parallel")]
399 {
400 let pool = build_thread_pool(num_threads)?;
401 Ok(pool.install(|| msa.into_dist::<M>()).average())
402 }
403 #[cfg(not(feature = "parallel"))]
404 {
405 let _ = num_threads;
406 Ok(msa.into_dist::<M>().average())
407 }
408}
409
410fn run_nj<A, M>(
419 msa: MSA<A>,
420 n_bootstrap_samples: usize,
421 on_event: Option<&dyn Fn(NJEvent)>,
422 num_threads: Option<usize>,
423 include_distance_matrix: bool,
424 include_average_distance: bool,
425) -> Result<NJResult, String>
426where
427 A: AlphabetEncoding + Send + Sync,
428 A::Symbol: Send + Sync,
429 M: ModelCalculation<A> + Send + Sync,
430{
431 let clade_counts =
432 bootstrap_clade_counts::<A, M>(&msa, n_bootstrap_samples, on_event, num_threads)?;
433
434 if let Some(cb) = on_event {
435 cb(NJEvent::ComputingDistances);
436 }
437
438 #[cfg(feature = "parallel")]
439 let dist = {
440 let pool = build_thread_pool(num_threads)?;
441 pool.install(|| msa.into_dist::<M>())
442 };
443 #[cfg(not(feature = "parallel"))]
444 let dist = msa.into_dist::<M>();
445
446 let distance_matrix = if include_distance_matrix {
447 Some(dist.to_result())
448 } else {
449 None
450 };
451 let average_distance = if include_average_distance {
452 Some(dist.average())
453 } else {
454 None
455 };
456
457 if let Some(cb) = on_event {
458 cb(NJEvent::RunningNJ);
459 }
460
461 let mut main_tree = dist.neighbor_joining()?;
462
463 let newick = match clade_counts {
464 Some(counts) => {
465 if let Some(cb) = on_event {
466 cb(NJEvent::AnnotatingBootstrap);
467 }
468 let main_idx_map: HashMap<String, usize> = msa.to_index_map();
469 add_bootstrap_to_tree(
470 &mut main_tree,
471 &main_idx_map,
472 msa.n_sequences,
473 &counts,
474 n_bootstrap_samples,
475 )?;
476 main_tree.to_newick()
477 }
478 None => main_tree.to_newick(),
479 };
480 Ok(NJResult {
481 newick,
482 distance_matrix,
483 average_distance,
484 })
485}
486
487pub fn nj(conf: NJConfig, on_event: Option<Box<dyn Fn(NJEvent)>>) -> Result<NJResult, NJError> {
503 let cb = on_event.as_deref();
504 let num_threads = conf.num_threads;
505 let include_distance_matrix = conf.return_distance_matrix;
506 let include_average_distance = conf.return_average_distance;
507 validate_msa(&conf.msa)?;
508 let n_sites = conf.msa[0].sequence.len();
509 if let Some(cb) = cb {
510 cb(NJEvent::MsaValidated {
511 n_sequences: conf.msa.len(),
512 n_sites,
513 });
514 }
515 let alphabet = conf.alphabet.unwrap_or_else(|| detect_alphabet(&conf.msa));
516 if let Some(cb) = cb {
517 cb(NJEvent::AlphabetDetected {
518 alphabet: alphabet.clone(),
519 });
520 }
521 let model = conf.substitution_model;
522 match alphabet {
523 Alphabet::DNA => {
524 let msa =
525 MSA::<DNA>::from_iter(conf.msa.into_iter().map(|s| (s.identifier, s.sequence)));
526 match model {
527 SubstitutionModel::PDiff => run_nj::<DNA, PDiff>(
528 msa,
529 conf.n_bootstrap_samples,
530 cb,
531 num_threads,
532 include_distance_matrix,
533 include_average_distance,
534 )
535 .map_err(NJError::AlgorithmFailure),
536 SubstitutionModel::JukesCantor => run_nj::<DNA, JukesCantor>(
537 msa,
538 conf.n_bootstrap_samples,
539 cb,
540 num_threads,
541 include_distance_matrix,
542 include_average_distance,
543 )
544 .map_err(NJError::AlgorithmFailure),
545 SubstitutionModel::Kimura2P => run_nj::<DNA, Kimura2P>(
546 msa,
547 conf.n_bootstrap_samples,
548 cb,
549 num_threads,
550 include_distance_matrix,
551 include_average_distance,
552 )
553 .map_err(NJError::AlgorithmFailure),
554 SubstitutionModel::Poisson => Err(NJError::IncompatibleModel {
555 model,
556 alphabet: Alphabet::DNA,
557 }),
558 }
559 }
560 Alphabet::Protein => {
561 let msa =
562 MSA::<Protein>::from_iter(conf.msa.into_iter().map(|s| (s.identifier, s.sequence)));
563 match model {
564 SubstitutionModel::Poisson => run_nj::<Protein, Poisson>(
565 msa,
566 conf.n_bootstrap_samples,
567 cb,
568 num_threads,
569 include_distance_matrix,
570 include_average_distance,
571 )
572 .map_err(NJError::AlgorithmFailure),
573 SubstitutionModel::PDiff => run_nj::<Protein, PDiff>(
574 msa,
575 conf.n_bootstrap_samples,
576 cb,
577 num_threads,
578 include_distance_matrix,
579 include_average_distance,
580 )
581 .map_err(NJError::AlgorithmFailure),
582 SubstitutionModel::JukesCantor | SubstitutionModel::Kimura2P => {
583 Err(NJError::IncompatibleModel {
584 model,
585 alphabet: Alphabet::Protein,
586 })
587 }
588 }
589 }
590 }
591}
592
593pub fn distance_matrix(conf: DistConfig) -> Result<DistanceResult, NJError> {
600 let num_threads = conf.num_threads;
601 validate_msa(&conf.msa)?;
602 let alphabet = conf.alphabet.unwrap_or_else(|| detect_alphabet(&conf.msa));
603 let model = conf.substitution_model;
604 match alphabet {
605 Alphabet::DNA => {
606 let msa =
607 MSA::<DNA>::from_iter(conf.msa.into_iter().map(|s| (s.identifier, s.sequence)));
608 match model {
609 SubstitutionModel::PDiff => run_distance_matrix::<DNA, PDiff>(msa, num_threads)
610 .map_err(NJError::AlgorithmFailure),
611 SubstitutionModel::JukesCantor => {
612 run_distance_matrix::<DNA, JukesCantor>(msa, num_threads)
613 .map_err(NJError::AlgorithmFailure)
614 }
615 SubstitutionModel::Kimura2P => {
616 run_distance_matrix::<DNA, Kimura2P>(msa, num_threads)
617 .map_err(NJError::AlgorithmFailure)
618 }
619 SubstitutionModel::Poisson => Err(NJError::IncompatibleModel {
620 model,
621 alphabet: Alphabet::DNA,
622 }),
623 }
624 }
625 Alphabet::Protein => {
626 let msa =
627 MSA::<Protein>::from_iter(conf.msa.into_iter().map(|s| (s.identifier, s.sequence)));
628 match model {
629 SubstitutionModel::Poisson => {
630 run_distance_matrix::<Protein, Poisson>(msa, num_threads)
631 .map_err(NJError::AlgorithmFailure)
632 }
633 SubstitutionModel::PDiff => run_distance_matrix::<Protein, PDiff>(msa, num_threads)
634 .map_err(NJError::AlgorithmFailure),
635 SubstitutionModel::JukesCantor | SubstitutionModel::Kimura2P => {
636 Err(NJError::IncompatibleModel {
637 model,
638 alphabet: Alphabet::Protein,
639 })
640 }
641 }
642 }
643 }
644}
645
646pub fn average_distance(conf: DistConfig) -> Result<f64, NJError> {
652 let num_threads = conf.num_threads;
653 validate_msa(&conf.msa)?;
654 let alphabet = conf.alphabet.unwrap_or_else(|| detect_alphabet(&conf.msa));
655 let model = conf.substitution_model;
656 match alphabet {
657 Alphabet::DNA => {
658 let msa =
659 MSA::<DNA>::from_iter(conf.msa.into_iter().map(|s| (s.identifier, s.sequence)));
660 match model {
661 SubstitutionModel::PDiff => run_average_distance::<DNA, PDiff>(msa, num_threads)
662 .map_err(NJError::AlgorithmFailure),
663 SubstitutionModel::JukesCantor => {
664 run_average_distance::<DNA, JukesCantor>(msa, num_threads)
665 .map_err(NJError::AlgorithmFailure)
666 }
667 SubstitutionModel::Kimura2P => {
668 run_average_distance::<DNA, Kimura2P>(msa, num_threads)
669 .map_err(NJError::AlgorithmFailure)
670 }
671 SubstitutionModel::Poisson => Err(NJError::IncompatibleModel {
672 model,
673 alphabet: Alphabet::DNA,
674 }),
675 }
676 }
677 Alphabet::Protein => {
678 let msa =
679 MSA::<Protein>::from_iter(conf.msa.into_iter().map(|s| (s.identifier, s.sequence)));
680 match model {
681 SubstitutionModel::Poisson => {
682 run_average_distance::<Protein, Poisson>(msa, num_threads)
683 .map_err(NJError::AlgorithmFailure)
684 }
685 SubstitutionModel::PDiff => {
686 run_average_distance::<Protein, PDiff>(msa, num_threads)
687 .map_err(NJError::AlgorithmFailure)
688 }
689 SubstitutionModel::JukesCantor | SubstitutionModel::Kimura2P => {
690 Err(NJError::IncompatibleModel {
691 model,
692 alphabet: Alphabet::Protein,
693 })
694 }
695 }
696 }
697 }
698}
699
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::DistConfig;
    use crate::models::SubstitutionModel;

    /// Turns `(identifier, sequence)` pairs into owned sequence records.
    fn seqs(pairs: &[(&str, &str)]) -> Vec<SequenceObject> {
        pairs
            .iter()
            .map(|(id, seq)| SequenceObject {
                identifier: id.to_string(),
                sequence: seq.to_string(),
            })
            .collect()
    }

    /// NJ configuration with the p-distance model, no bootstrapping, automatic
    /// alphabet detection, and a default thread count.
    fn nj_conf(pairs: &[(&str, &str)], include_dm: bool, include_avg: bool) -> NJConfig {
        NJConfig {
            msa: seqs(pairs),
            n_bootstrap_samples: 0,
            substitution_model: SubstitutionModel::PDiff,
            alphabet: None,
            num_threads: None,
            return_distance_matrix: include_dm,
            return_average_distance: include_avg,
        }
    }

    /// Same as `nj_conf` without extras, but with an explicit model.
    fn nj_conf_with_model(pairs: &[(&str, &str)], model: SubstitutionModel) -> NJConfig {
        NJConfig {
            substitution_model: model,
            ..nj_conf(pairs, false, false)
        }
    }

    /// Distance configuration with automatic alphabet detection and a default
    /// thread count.
    fn dist_conf(pairs: &[(&str, &str)], model: SubstitutionModel) -> DistConfig {
        DistConfig {
            msa: seqs(pairs),
            substitution_model: model,
            alphabet: None,
            num_threads: None,
        }
    }

    // ---- nj -------------------------------------------------------------

    #[test]
    fn test_nj_wrapper_simple_tree() {
        let conf = nj_conf(&[("A", "ACGTCG"), ("B", "ACG-GC")], false, false);
        let result = nj(conf, None).expect("NJ failed");
        assert_eq!(result.newick, "(A:0.200,B:0.200);");
    }

    #[test]
    fn test_nj_wrapper_adds_semicolon() {
        let conf = nj_conf(&[("Seq0", "A"), ("Seq1", "A")], false, false);
        let out = nj(conf, None).unwrap();
        assert!(out.newick.ends_with(';'));
    }

    #[test]
    fn test_nj_deterministic_order() {
        let conf = nj_conf(
            &[("Seq0", "ACGTCG"), ("Seq1", "ACG-GC"), ("Seq2", "ACGCGT")],
            false,
            false,
        );

        let t1 = nj(conf.clone(), None).unwrap();
        let t2 = nj(conf, None).unwrap();
        assert_eq!(t1, t2);
    }

    #[test]
    fn test_nj_wrapper_empty_msa() {
        let result = nj(nj_conf(&[], false, false), None);
        assert!(result.is_err());
    }

    #[test]
    fn test_nj_wrapper_incorrect_model_for_alphabet() {
        // Poisson is a protein model; the input is detected as DNA.
        let conf = nj_conf_with_model(
            &[("Seq0", "ACGTCG"), ("Seq1", "ACG-GC")],
            SubstitutionModel::Poisson,
        );
        let result = nj(conf, None);
        assert!(result.is_err());
    }

    #[test]
    fn test_nj_wrapper_incorrect_model_for_protein() {
        // Jukes-Cantor is a DNA model; the input is detected as protein.
        let conf = nj_conf_with_model(
            &[("Seq0", "ACDEFGH"), ("Seq1", "ACD-FGH")],
            SubstitutionModel::JukesCantor,
        );
        let result = nj(conf, None);
        assert!(result.is_err());
    }

    #[test]
    fn test_nj_result_no_extras_by_default() {
        let result = nj(nj_conf(&[("A", "ACGT"), ("B", "ACGA")], false, false), None).unwrap();
        assert!(result.distance_matrix.is_none());
        assert!(result.average_distance.is_none());
    }

    #[test]
    fn test_nj_result_include_distance_matrix() {
        let result = nj(nj_conf(&[("A", "ACGT"), ("B", "ACGA")], true, false), None).unwrap();
        let dm = result
            .distance_matrix
            .expect("distance_matrix should be present");
        assert_eq!(dm.names, vec!["A", "B"]);
        assert_eq!(dm.matrix.len(), 2);
        assert_eq!(dm.matrix[0][0], 0.0);
        assert!((dm.matrix[0][1] - 0.25).abs() < 1e-12);
        assert!(result.average_distance.is_none());
    }

    #[test]
    fn test_nj_result_include_average_distance() {
        let result = nj(nj_conf(&[("A", "ACGT"), ("B", "ACGA")], false, true), None).unwrap();
        let avg = result
            .average_distance
            .expect("average_distance should be present");
        assert!((avg - 0.25).abs() < 1e-12);
        assert!(result.distance_matrix.is_none());
    }

    #[test]
    fn test_nj_result_include_both() {
        let result = nj(nj_conf(&[("A", "ACGT"), ("B", "ACGA")], true, true), None).unwrap();
        let dm = result
            .distance_matrix
            .as_ref()
            .expect("distance_matrix should be present");
        let avg = result
            .average_distance
            .expect("average_distance should be present");
        // With two taxa the single pairwise distance is also the average.
        assert!((dm.matrix[0][1] - avg).abs() < 1e-12);
    }

    #[test]
    fn test_nj_result_distance_matrix_consistent_with_newick() {
        let result = nj(nj_conf(&[("A", "ACGT"), ("B", "ACGA")], true, false), None).unwrap();
        let dm = result.distance_matrix.unwrap();
        // Checks a zero diagonal and symmetry of the returned matrix.
        assert_eq!(dm.matrix[0][0], 0.0);
        assert_eq!(dm.matrix[1][1], 0.0);
        assert!((dm.matrix[0][1] - dm.matrix[1][0]).abs() < 1e-12);
    }

    // ---- distance_matrix ------------------------------------------------

    #[test]
    fn test_distance_matrix_names_and_shape() {
        let conf = dist_conf(&[("A", "ACGT"), ("B", "ACGA")], SubstitutionModel::PDiff);
        let result = distance_matrix(conf).unwrap();
        assert_eq!(result.names, vec!["A", "B"]);
        assert_eq!(result.matrix.len(), 2);
        assert_eq!(result.matrix[0].len(), 2);
        assert_eq!(result.matrix[1].len(), 2);
    }

    #[test]
    fn test_distance_matrix_diagonal_zero() {
        let conf = dist_conf(
            &[("A", "ACGT"), ("B", "ACGA"), ("C", "AGGT")],
            SubstitutionModel::PDiff,
        );
        let result = distance_matrix(conf).unwrap();
        for i in 0..3 {
            assert_eq!(result.matrix[i][i], 0.0);
        }
    }

    #[test]
    fn test_distance_matrix_symmetric() {
        let conf = dist_conf(
            &[("A", "ACGT"), ("B", "ACGA"), ("C", "AGGT")],
            SubstitutionModel::PDiff,
        );
        let result = distance_matrix(conf).unwrap();
        for i in 0..3 {
            for j in 0..3 {
                assert_eq!(result.matrix[i][j], result.matrix[j][i]);
            }
        }
    }

    #[test]
    fn test_distance_matrix_pdiff_known_value() {
        // One differing site out of four.
        let conf = dist_conf(&[("A", "ACGT"), ("B", "ACGA")], SubstitutionModel::PDiff);
        let result = distance_matrix(conf).unwrap();
        assert!((result.matrix[0][1] - 0.25).abs() < 1e-12);
        assert!((result.matrix[1][0] - 0.25).abs() < 1e-12);
    }

    #[test]
    fn test_distance_matrix_identical_sequences_zero() {
        let conf = dist_conf(&[("A", "ACGT"), ("B", "ACGT")], SubstitutionModel::PDiff);
        let result = distance_matrix(conf).unwrap();
        assert_eq!(result.matrix[0][1], 0.0);
    }

    #[test]
    fn test_distance_matrix_jukes_cantor_dna() {
        let conf = dist_conf(
            &[("A", "ACGT"), ("B", "ACGA"), ("C", "AGGT")],
            SubstitutionModel::JukesCantor,
        );
        let result = distance_matrix(conf).unwrap();
        // d = -3/4 * ln(1 - 4/3 * p) with p = 0.25.
        let expected = -0.75_f64 * (1.0_f64 - (4.0_f64 / 3.0) * 0.25).ln();
        assert!((result.matrix[0][1] - expected).abs() < 1e-10);
    }

    #[test]
    fn test_distance_matrix_kimura2p_dna() {
        let conf = dist_conf(
            &[("A", "ACGT"), ("B", "ACGA"), ("C", "AGGT")],
            SubstitutionModel::Kimura2P,
        );
        let result = distance_matrix(conf).unwrap();
        assert_eq!(result.names, vec!["A", "B", "C"]);
        assert!(result.matrix[0][0] == 0.0);
    }

    #[test]
    fn test_distance_matrix_poisson_protein() {
        let conf = dist_conf(
            &[("A", "ACDEFGH"), ("B", "ACDEFGK")],
            SubstitutionModel::Poisson,
        );
        let result = distance_matrix(conf).unwrap();
        // d = -ln(1 - p) with p = 1/7.
        let expected = -(1.0_f64 - 1.0 / 7.0).ln();
        assert!((result.matrix[0][1] - expected).abs() < 1e-10);
    }

    #[test]
    fn test_distance_matrix_pdiff_protein() {
        let conf = dist_conf(
            &[("A", "ACDEFGH"), ("B", "ACDEFGK")],
            SubstitutionModel::PDiff,
        );
        let result = distance_matrix(conf).unwrap();
        assert!((result.matrix[0][1] - 1.0 / 7.0).abs() < 1e-12);
    }

    #[test]
    fn test_distance_matrix_empty_msa_errors() {
        let conf = dist_conf(&[], SubstitutionModel::PDiff);
        assert!(distance_matrix(conf).is_err());
    }

    #[test]
    fn test_distance_matrix_incompatible_model_errors() {
        // Protein model on DNA input.
        let conf = dist_conf(&[("A", "ACGT"), ("B", "ACGA")], SubstitutionModel::Poisson);
        assert!(distance_matrix(conf).is_err());
        // DNA model on protein input.
        let conf = dist_conf(
            &[("A", "ACDEFGH"), ("B", "ACDEFGK")],
            SubstitutionModel::JukesCantor,
        );
        assert!(distance_matrix(conf).is_err());
    }

    // ---- average_distance -----------------------------------------------

    #[test]
    fn test_average_distance_identical_sequences_zero() {
        let conf = dist_conf(&[("A", "ACGT"), ("B", "ACGT")], SubstitutionModel::PDiff);
        let avg = average_distance(conf).unwrap();
        assert_eq!(avg, 0.0);
    }

    #[test]
    fn test_average_distance_two_taxa_equals_pairwise() {
        let conf = dist_conf(&[("A", "ACGT"), ("B", "ACGA")], SubstitutionModel::PDiff);
        let avg = average_distance(conf).unwrap();
        assert!((avg - 0.25).abs() < 1e-12);
    }

    #[test]
    fn test_average_distance_three_taxa_known_value() {
        // Pairwise p-distances are 0.25, 0.25 and 0.5, so the mean is 1/3.
        let conf = dist_conf(
            &[("A", "ACGT"), ("B", "ACGA"), ("C", "AGGT")],
            SubstitutionModel::PDiff,
        );
        let avg = average_distance(conf).unwrap();
        assert!((avg - 1.0 / 3.0).abs() < 1e-12);
    }

    #[test]
    fn test_average_distance_jukes_cantor_dna() {
        let conf = dist_conf(
            &[("A", "ACGT"), ("B", "ACGA")],
            SubstitutionModel::JukesCantor,
        );
        let avg = average_distance(conf).unwrap();
        let expected = -0.75_f64 * (1.0_f64 - (4.0_f64 / 3.0) * 0.25).ln();
        assert!((avg - expected).abs() < 1e-10);
    }

    #[test]
    fn test_average_distance_empty_msa_errors() {
        let conf = dist_conf(&[], SubstitutionModel::PDiff);
        assert!(average_distance(conf).is_err());
    }

    #[test]
    fn test_average_distance_incompatible_model_errors() {
        let conf = dist_conf(&[("A", "ACGT"), ("B", "ACGA")], SubstitutionModel::Poisson);
        assert!(average_distance(conf).is_err());
    }

    // ---- detect_alphabet ------------------------------------------------

    #[test]
    fn test_detect_alphabet_dna() {
        let msa = seqs(&[("Seq0", "ACGTACGT"), ("Seq1", "ACG-ACGT")]);
        assert_eq!(detect_alphabet(&msa), Alphabet::DNA);
    }

    #[test]
    fn test_detect_alphabet_dna_iupac() {
        let msa = seqs(&[("Seq0", "ACGTRYWSMKHBDVNU")]);
        assert_eq!(detect_alphabet(&msa), Alphabet::DNA);
    }

    #[test]
    fn test_detect_alphabet_protein() {
        let msa = seqs(&[("Seq0", "ACDEFGHIK"), ("Seq1", "ACD-FGHIK")]);
        assert_eq!(detect_alphabet(&msa), Alphabet::Protein);
    }
}
1164
#[cfg(all(test, feature = "parallel"))]
mod parallel_tests {
    use super::*;
    use crate::models::SubstitutionModel;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Three short DNA sequences of equal length.
    fn three_seq_dna() -> Vec<SequenceObject> {
        [("A", "ACGTACGT"), ("B", "ACGCACGT"), ("C", "ACGTACGC")]
            .iter()
            .map(|&(id, seq)| SequenceObject {
                identifier: id.into(),
                sequence: seq.into(),
            })
            .collect()
    }

    /// p-distance configuration with the given number of bootstrap samples and
    /// no optional outputs.
    fn base_conf(msa: Vec<SequenceObject>, n_bootstrap_samples: usize) -> NJConfig {
        NJConfig {
            msa,
            n_bootstrap_samples,
            substitution_model: SubstitutionModel::PDiff,
            alphabet: None,
            num_threads: None,
            return_distance_matrix: false,
            return_average_distance: false,
        }
    }

    #[test]
    fn test_parallel_bootstrap_returns_valid_newick() {
        let outcome = nj(base_conf(three_seq_dna(), 20), None).expect("parallel NJ failed");
        assert!(outcome.newick.ends_with(';'));
        assert!(outcome.newick.contains(':'));
    }

    #[test]
    fn test_parallel_progress_fires_exactly_n_times() {
        let expected = 10_usize;
        let seen = Arc::new(AtomicUsize::new(0));
        let seen_in_cb = Arc::clone(&seen);
        let cb: Box<dyn Fn(NJEvent)> = Box::new(move |event: NJEvent| {
            if matches!(event, NJEvent::BootstrapProgress { .. }) {
                seen_in_cb.fetch_add(1, Ordering::SeqCst);
            }
        });
        nj(base_conf(three_seq_dna(), expected), Some(cb)).unwrap();
        assert_eq!(seen.load(Ordering::SeqCst), expected);
    }

    #[test]
    fn test_parallel_progress_last_call_is_total() {
        let total = 8_usize;
        let latest = Arc::new(AtomicUsize::new(0));
        let latest_in_cb = Arc::clone(&latest);
        let cb: Box<dyn Fn(NJEvent)> = Box::new(move |event: NJEvent| {
            // Remember the most recently reported completion count.
            if let NJEvent::BootstrapProgress { completed, .. } = event {
                latest_in_cb.store(completed, Ordering::SeqCst);
            }
        });
        nj(base_conf(three_seq_dna(), total), Some(cb)).unwrap();
        assert_eq!(latest.load(Ordering::SeqCst), total);
    }
}
1235}