Skip to main content

nj/
lib.rs

1//! Neighbor-Joining phylogenetic tree inference library.
2//!
3//! # Data flow
4//!
5//! ```text
6//! [FASTA / Python dict / JS object]
7//!         │
8//!         ▼
9//!      NJConfig  (config.rs)
10//!         │
11//!         ▼
12//!   detect_alphabet()  ──►  Alphabet::DNA | Alphabet::Protein
13//!         │
14//!         ▼
15//!     MSA<DNA|Protein>  (msa.rs)
16//!      ├── bootstrap() ──► bootstrap_clade_counts()
17//!      └── into_dist::<Model>()
18//!               │
19//!               ▼
//!           DistMat  (distance_matrix.rs)
21//!               │
22//!               ▼
23//!         neighbor_joining()  ──►  NJState::run()  (nj.rs)
24//!               │
25//!               ▼
26//!           TreeNode  (tree.rs)
27//!               │
28//!               ▼
29//!           to_newick()  ──►  Newick String
30//! ```
31//!
32//! # Public API
33//!
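//! The main entry point is [`nj()`], which accepts an [`NJConfig`] and returns an
//! [`NJResult`] holding the Newick string and, optionally, the distance matrix and
//! average pairwise distance. [`distance_matrix()`] and [`average_distance()`]
//! compute distances from a [`DistConfig`] without building a tree. Everything
//! else is internal implementation detail exposed only to the Python and WASM
//! wrapper crates.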
37//!
38//! # Model–alphabet compatibility
39//!
40//! | Model | DNA | Protein |
41//! |-------|-----|---------|
42//! | `PDiff` | ✓ | ✓ |
43//! | `JukesCantor` | ✓ | — |
44//! | `Kimura2P` | ✓ | — |
45//! | `Poisson` | — | ✓ |
46//!
//! Providing an incompatible model returns [`NJError::IncompatibleModel`] from
//! [`nj()`], [`distance_matrix()`] and [`average_distance()`].
48pub mod alphabet;
49pub mod config;
50pub mod distance_matrix;
51pub mod error;
52pub mod event;
53pub mod fasta;
54pub mod models;
55pub mod msa;
56pub mod nj;
57pub mod tree;
58
59use bitvec::prelude::{BitVec, Lsb0, bitvec};
60use std::collections::HashMap;
61
62use crate::alphabet::{Alphabet, AlphabetEncoding, DNA, Protein};
63use crate::config::SubstitutionModel;
64pub use crate::config::{DistConfig, MSA, NJConfig, NJResult, SequenceObject};
65use crate::distance_matrix::DistMat;
66pub use crate::distance_matrix::DistanceResult;
67pub use crate::error::NJError;
68pub use crate::event::{LogLevel, NJEvent};
69pub use crate::fasta::parse_fasta;
70use crate::models::{JukesCantor, Kimura2P, ModelCalculation, PDiff, Poisson};
71use crate::tree::{NameOrSupport, TreeNode};
72
73/// Fills `out` with the leaf indices of all taxa in the subtree rooted at `node`.
74///
75/// Bits in `out` are set to `true` for each leaf encountered. The bit position
76/// is looked up in `idx` by the leaf's name label. Returns `Err` if a leaf
77/// has no name label (should not occur for well-formed NJ trees).
78fn bitset_of(
79    node: &TreeNode,
80    idx: &HashMap<String, usize>,
81    out: &mut BitVec<u8, Lsb0>,
82) -> Result<(), String> {
83    match &node.children {
84        None => match &node.label {
85            Some(NameOrSupport::Name(name)) => {
86                let i = idx[name];
87                out.set(i, true);
88                Ok(())
89            }
90            _ => Err("Leaf node without a name label".into()),
91        },
92        Some([l, r]) => {
93            bitset_of(l, idx, out)?;
94            bitset_of(r, idx, out)?;
95            Ok(())
96        }
97    }
98}
99
100/// Recursively counts how many times each non-trivial clade appears in `tree`.
101///
102/// A clade is represented as a raw-byte encoding of a `BitVec` over the `n_taxa`
103/// leaf indices. Only clades with `1 < size < n_taxa` (i.e. proper internal
104/// clades) are counted. Each call increments the clade's entry in `counter` by 1.
105/// Used by [`bootstrap_clade_counts`] to aggregate over bootstrap replicates.
106fn count_clades(
107    tree: &TreeNode,
108    idx: &HashMap<String, usize>,
109    n_taxa: usize,
110    counter: &mut HashMap<Vec<u8>, usize>,
111) -> Result<(), String> {
112    if let Some([l, r]) = &tree.children {
113        let mut bv = bitvec![u8, Lsb0; 0; n_taxa];
114        bitset_of(tree, idx, &mut bv)?;
115
116        let n = bv.count_ones();
117        if n > 1 && n < n_taxa {
118            counter
119                .entry(bv.as_raw_slice().to_vec())
120                .and_modify(|c| *c += 1)
121                .or_insert(1);
122        }
123
124        count_clades(l, idx, n_taxa, counter)?;
125        count_clades(r, idx, n_taxa, counter)?;
126    }
127    Ok(())
128}
129
/// Creates a dedicated Rayon thread pool.
///
/// `Some(n)` requests `n` worker threads. `None` leaves the worker count at
/// Rayon's default (one thread per logical CPU).
///
/// # Errors
///
/// Returns the builder's error, rendered as a `String`, if the pool cannot be
/// constructed.
#[cfg(feature = "parallel")]
pub(crate) fn build_thread_pool(num_threads: Option<usize>) -> Result<rayon::ThreadPool, String> {
    let builder = match num_threads {
        Some(n) => rayon::ThreadPoolBuilder::new().num_threads(n),
        None => rayon::ThreadPoolBuilder::new(),
    };
    builder.build().map_err(|err| err.to_string())
}
142
/// Parallel bootstrap driver: runs every replicate on a dedicated Rayon pool
/// and streams per-replicate clade maps back over an MPSC channel.
///
/// The calling thread merges the maps and fires `on_event`, so the callback
/// always runs on one thread and needs no `Sync` bound. The Rayon work runs on a
/// scoped OS thread (`std::thread::scope`), which lets `msa` and `idx_map` be
/// borrowed without `'static` lifetimes while the caller receives results
/// concurrently, one [`NJEvent::BootstrapProgress`] per completed replicate.
///
/// # Errors
///
/// - Returns the first replicate error received, covering resampling, NJ
///   construction and clade extraction. Once an error is seen, replicates that
///   have not started yet are skipped.
/// - Returns `Err` if the pool cannot be built.
/// - Returns `Err` if the worker thread panics. The scoped thread is joined
///   explicitly, so the panic is not re-raised out of `thread::scope`.
#[cfg(feature = "parallel")]
fn bootstrap_clade_counts_parallel<A, M>(
    msa: &MSA<A>,
    n_bootstrap_samples: usize,
    idx_map: &HashMap<String, usize>,
    n_taxa: usize,
    on_event: Option<&dyn Fn(NJEvent)>,
    num_threads: Option<usize>,
) -> Result<HashMap<Vec<u8>, usize>, String>
where
    A: AlphabetEncoding + Send + Sync,
    A::Symbol: Send + Sync,
    M: ModelCalculation<A> + Send + Sync,
{
    use rayon::iter::{IntoParallelIterator, ParallelIterator};
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::mpsc;

    let pool = build_thread_pool(num_threads)?;
    let (tx, rx) = mpsc::channel::<Result<HashMap<Vec<u8>, usize>, String>>();
    // Set by the merging thread after an error so that workers skip replicates
    // whose results would be discarded anyway.
    let cancelled = AtomicBool::new(false);
    let mut counter: HashMap<Vec<u8>, usize> = HashMap::new();

    std::thread::scope(|scope| -> Result<(), String> {
        // `pool.install` makes the dedicated pool active for the closure,
        // bounding parallelism to `num_threads` workers.
        let worker = scope.spawn(|| {
            pool.install(|| {
                (0..n_bootstrap_samples)
                    .into_par_iter()
                    .for_each_with(tx, |sender, _| {
                        if cancelled.load(Ordering::Relaxed) {
                            return;
                        }
                        let result = (|| -> Result<HashMap<Vec<u8>, usize>, String> {
                            let tree = msa.bootstrap()?.into_dist::<M>().neighbor_joining()?;
                            let mut local = HashMap::new();
                            count_clades(&tree, idx_map, n_taxa, &mut local)?;
                            Ok(local)
                        })();
                        // `rx` outlives the scope, so a send can only fail if the
                        // receiver is gone, which does not happen here.
                        let _ = sender.send(result);
                    });
            });
        });

        // Calling thread: receive results as they arrive, merge, report progress.
        let mut merge = || -> Result<(), String> {
            for completed in 1..=n_bootstrap_samples {
                let local = match rx.recv() {
                    Ok(Ok(local)) => local,
                    Ok(Err(e)) => return Err(e),
                    Err(_) => return Err("bootstrap channel closed unexpectedly".into()),
                };
                for (clade, count) in local {
                    *counter.entry(clade).or_insert(0) += count;
                }
                if let Some(cb) = on_event {
                    cb(NJEvent::BootstrapProgress {
                        completed,
                        total: n_bootstrap_samples,
                    });
                }
            }
            Ok(())
        };
        let merged = merge();

        if merged.is_err() {
            cancelled.store(true, Ordering::Relaxed);
        }
        // Join explicitly: an un-joined panicking scoped thread would make
        // `thread::scope` itself panic.
        if worker.join().is_err() {
            return Err("bootstrap worker thread panicked".into());
        }
        merged
    })?;

    Ok(counter)
}
222
223/// Performs bootstrap sampling and counts clades across bootstrap trees.
224///
225/// When the `parallel` feature is enabled, replicates run on the Rayon thread
226/// pool via [`bootstrap_clade_counts_parallel`]; the `on_event` callback is
227/// still called after each replicate from the main thread. Without the feature
228/// the loop is sequential and `on_event` fires in order.
229fn bootstrap_clade_counts<A, M>(
230    msa: &MSA<A>,
231    n_bootstrap_samples: usize,
232    on_event: Option<&dyn Fn(NJEvent)>,
233    num_threads: Option<usize>,
234) -> Result<Option<HashMap<Vec<u8>, usize>>, String>
235where
236    A: AlphabetEncoding + Send + Sync,
237    A::Symbol: Send + Sync,
238    M: ModelCalculation<A> + Send + Sync,
239{
240    if n_bootstrap_samples == 0 {
241        return Ok(None);
242    }
243    if let Some(cb) = on_event {
244        cb(NJEvent::BootstrapStarted {
245            total: n_bootstrap_samples,
246        });
247    }
248    let idx_map: HashMap<String, usize> = msa.to_index_map();
249    let n_taxa = msa.n_sequences;
250
251    #[cfg(feature = "parallel")]
252    let counter = bootstrap_clade_counts_parallel::<A, M>(
253        msa,
254        n_bootstrap_samples,
255        &idx_map,
256        n_taxa,
257        on_event,
258        num_threads,
259    )?;
260
261    #[cfg(not(feature = "parallel"))]
262    let counter = {
263        let _ = num_threads;
264        let mut c = HashMap::new();
265        for i in 0..n_bootstrap_samples {
266            let tree = msa
267                .bootstrap()?
268                .into_dist::<M>()
269                .neighbor_joining()
270                .expect("NJ bootstrap iteration failed");
271            count_clades(&tree, &idx_map, n_taxa, &mut c)?;
272            if let Some(cb) = on_event {
273                cb(NJEvent::BootstrapProgress {
274                    completed: i + 1,
275                    total: n_bootstrap_samples,
276                });
277            }
278        }
279        c
280    };
281
282    Ok(Some(counter))
283}
284
285/// Annotates internal nodes with bootstrap support values from `counts`.
286///
287/// For each internal node, computes its clade `BitVec`, looks up the count in
288/// `counts`, normalises it to a percentage (`count * 100 / n_bootstrap_samples`),
289/// and assigns a [`NameOrSupport::Support`] label if a matching entry is found.
290/// Nodes whose clade was never observed in bootstrap replicates receive no label.
291fn add_bootstrap_to_tree(
292    node: &mut TreeNode,
293    idx: &HashMap<String, usize>,
294    n_taxa: usize,
295    counts: &HashMap<Vec<u8>, usize>,
296    n_bootstrap_samples: usize,
297) -> Result<(), String> {
298    if node.children.is_some() {
299        let mut bv = bitvec![u8, Lsb0; 0; n_taxa];
300        bitset_of(node, idx, &mut bv)?;
301
302        let n = bv.count_ones();
303        if n > 1 && n < n_taxa {
304            if let Some(c) = counts.get(&bv.as_raw_slice().to_vec()) {
305                let pct = c * 100 / n_bootstrap_samples;
306                node.label = Some(NameOrSupport::Support(pct));
307            }
308        }
309
310        if let Some([l, r]) = &mut node.children {
311            add_bootstrap_to_tree(l, idx, n_taxa, counts, n_bootstrap_samples)?;
312            add_bootstrap_to_tree(r, idx, n_taxa, counts, n_bootstrap_samples)?;
313        }
314    }
315    Ok(())
316}
317
318/// Validates that the MSA is non-empty and all sequences have equal length.
319fn validate_msa(msa: &[SequenceObject]) -> Result<(), NJError> {
320    if msa.is_empty() {
321        return Err(NJError::EmptyMsa);
322    }
323    let expected_len = msa[0].sequence.len();
324    if expected_len == 0 {
325        return Err(NJError::EmptySequence);
326    }
327    for s in msa {
328        if s.sequence.len() != expected_len {
329            return Err(NJError::SequenceLengthMismatch {
330                expected: expected_len,
331                got: s.sequence.len(),
332                identifier: s.identifier.clone(),
333            });
334        }
335    }
336    Ok(())
337}
338
339/// Heuristically detects whether the MSA contains DNA or protein sequences.
340///
341/// Returns [`Alphabet::DNA`] unless any sequence contains a byte that is not
342/// in the DNA character set (case-insensitive), in which case
343/// [`Alphabet::Protein`] is returned. The DNA set includes standard bases
344/// `{A, C, G, T}`, uridine `U` (RNA), `N` (unknown), gap `-`, and all 11
345/// IUPAC ambiguity codes `{R, Y, S, W, K, M, B, D, H, V}`.
346fn detect_alphabet(msa: &[SequenceObject]) -> Alphabet {
347    let mut is_protein = false;
348
349    'outer: for seq in msa {
350        for c in seq.sequence.bytes() {
351            match c.to_ascii_uppercase() {
352                b'A' | b'C' | b'G' | b'T' | b'U' | b'N' | b'-' | b'R' | b'Y' | b'S' | b'W'
353                | b'K' | b'M' | b'B' | b'D' | b'H' | b'V' => { /* still possible DNA */ }
354                _ => {
355                    is_protein = true;
356                    break 'outer;
357                }
358            }
359        }
360    }
361
362    if is_protein {
363        Alphabet::Protein
364    } else {
365        Alphabet::DNA
366    }
367}
368
369/// Runs distance matrix computation with model `M` on alphabet `A`.
370fn run_distance_matrix<A, M>(
371    msa: MSA<A>,
372    num_threads: Option<usize>,
373) -> Result<DistanceResult, String>
374where
375    A: AlphabetEncoding + Send + Sync,
376    A::Symbol: Send + Sync,
377    M: ModelCalculation<A> + Send + Sync,
378{
379    #[cfg(feature = "parallel")]
380    {
381        let pool = build_thread_pool(num_threads)?;
382        Ok(pool.install(|| msa.into_dist::<M>()).into_result())
383    }
384    #[cfg(not(feature = "parallel"))]
385    {
386        let _ = num_threads;
387        Ok(msa.into_dist::<M>().into_result())
388    }
389}
390
391/// Runs average distance computation with model `M` on alphabet `A`.
392fn run_average_distance<A, M>(msa: MSA<A>, num_threads: Option<usize>) -> Result<f64, String>
393where
394    A: AlphabetEncoding + Send + Sync,
395    A::Symbol: Send + Sync,
396    M: ModelCalculation<A> + Send + Sync,
397{
398    #[cfg(feature = "parallel")]
399    {
400        let pool = build_thread_pool(num_threads)?;
401        Ok(pool.install(|| msa.into_dist::<M>()).average())
402    }
403    #[cfg(not(feature = "parallel"))]
404    {
405        let _ = num_threads;
406        Ok(msa.into_dist::<M>().average())
407    }
408}
409
410/// Runs NJ with model `M` on alphabet `A` and returns an [`NJResult`].
411///
412/// If `n_bootstrap_samples > 0`, generates that many bootstrap replicates,
413/// collects clade counts via [`bootstrap_clade_counts`], runs NJ on the
414/// original distances, and annotates the tree before serialising to Newick.
415/// When `include_distance_matrix` or `include_average_distance` is `true`,
416/// the corresponding values are captured from the distance matrix before NJ
417/// consumes it and included in the returned [`NJResult`].
418fn run_nj<A, M>(
419    msa: MSA<A>,
420    n_bootstrap_samples: usize,
421    on_event: Option<&dyn Fn(NJEvent)>,
422    num_threads: Option<usize>,
423    include_distance_matrix: bool,
424    include_average_distance: bool,
425) -> Result<NJResult, String>
426where
427    A: AlphabetEncoding + Send + Sync,
428    A::Symbol: Send + Sync,
429    M: ModelCalculation<A> + Send + Sync,
430{
431    let clade_counts =
432        bootstrap_clade_counts::<A, M>(&msa, n_bootstrap_samples, on_event, num_threads)?;
433
434    if let Some(cb) = on_event {
435        cb(NJEvent::ComputingDistances);
436    }
437
438    #[cfg(feature = "parallel")]
439    let dist = {
440        let pool = build_thread_pool(num_threads)?;
441        pool.install(|| msa.into_dist::<M>())
442    };
443    #[cfg(not(feature = "parallel"))]
444    let dist = msa.into_dist::<M>();
445
446    let distance_matrix = if include_distance_matrix {
447        Some(dist.to_result())
448    } else {
449        None
450    };
451    let average_distance = if include_average_distance {
452        Some(dist.average())
453    } else {
454        None
455    };
456
457    if let Some(cb) = on_event {
458        cb(NJEvent::RunningNJ);
459    }
460
461    let mut main_tree = dist.neighbor_joining()?;
462
463    let newick = match clade_counts {
464        Some(counts) => {
465            if let Some(cb) = on_event {
466                cb(NJEvent::AnnotatingBootstrap);
467            }
468            let main_idx_map: HashMap<String, usize> = msa.to_index_map();
469            add_bootstrap_to_tree(
470                &mut main_tree,
471                &main_idx_map,
472                msa.n_sequences,
473                &counts,
474                n_bootstrap_samples,
475            )?;
476            main_tree.to_newick()
477        }
478        None => main_tree.to_newick(),
479    };
480    Ok(NJResult {
481        newick,
482        distance_matrix,
483        average_distance,
484    })
485}
486
487/// Infers a phylogenetic tree from an aligned MSA and returns an [`NJResult`].
488///
489/// This is the single public entry point for the library. The alphabet is
490/// auto-detected from the sequences unless `conf.alphabet` is set; `conf.substitution_model`
491/// must be compatible with the alphabet (see the module-level compatibility
492/// table). Returns `Err` for an empty MSA, an incompatible model, or any
493/// internal NJ failure.
494///
495/// Set [`NJConfig::include_distance_matrix`] and/or
496/// [`NJConfig::include_average_distance`] to `true` to include those values in
497/// the returned [`NJResult`] alongside the Newick tree.
498///
499/// `on_event` is called with an [`NJEvent`] at each stage of the algorithm.
500/// Bootstrap progress is reported via [`NJEvent::BootstrapProgress`] after
501/// each replicate. Pass `None` if event reporting is not needed.
502pub fn nj(conf: NJConfig, on_event: Option<Box<dyn Fn(NJEvent)>>) -> Result<NJResult, NJError> {
503    let cb = on_event.as_deref();
504    let num_threads = conf.num_threads;
505    let include_distance_matrix = conf.return_distance_matrix;
506    let include_average_distance = conf.return_average_distance;
507    validate_msa(&conf.msa)?;
508    let n_sites = conf.msa[0].sequence.len();
509    if let Some(cb) = cb {
510        cb(NJEvent::MsaValidated {
511            n_sequences: conf.msa.len(),
512            n_sites,
513        });
514    }
515    let alphabet = conf.alphabet.unwrap_or_else(|| detect_alphabet(&conf.msa));
516    if let Some(cb) = cb {
517        cb(NJEvent::AlphabetDetected {
518            alphabet: alphabet.clone(),
519        });
520    }
521    let model = conf.substitution_model;
522    match alphabet {
523        Alphabet::DNA => {
524            let msa =
525                MSA::<DNA>::from_iter(conf.msa.into_iter().map(|s| (s.identifier, s.sequence)));
526            match model {
527                SubstitutionModel::PDiff => run_nj::<DNA, PDiff>(
528                    msa,
529                    conf.n_bootstrap_samples,
530                    cb,
531                    num_threads,
532                    include_distance_matrix,
533                    include_average_distance,
534                )
535                .map_err(NJError::AlgorithmFailure),
536                SubstitutionModel::JukesCantor => run_nj::<DNA, JukesCantor>(
537                    msa,
538                    conf.n_bootstrap_samples,
539                    cb,
540                    num_threads,
541                    include_distance_matrix,
542                    include_average_distance,
543                )
544                .map_err(NJError::AlgorithmFailure),
545                SubstitutionModel::Kimura2P => run_nj::<DNA, Kimura2P>(
546                    msa,
547                    conf.n_bootstrap_samples,
548                    cb,
549                    num_threads,
550                    include_distance_matrix,
551                    include_average_distance,
552                )
553                .map_err(NJError::AlgorithmFailure),
554                SubstitutionModel::Poisson => Err(NJError::IncompatibleModel {
555                    model,
556                    alphabet: Alphabet::DNA,
557                }),
558            }
559        }
560        Alphabet::Protein => {
561            let msa =
562                MSA::<Protein>::from_iter(conf.msa.into_iter().map(|s| (s.identifier, s.sequence)));
563            match model {
564                SubstitutionModel::Poisson => run_nj::<Protein, Poisson>(
565                    msa,
566                    conf.n_bootstrap_samples,
567                    cb,
568                    num_threads,
569                    include_distance_matrix,
570                    include_average_distance,
571                )
572                .map_err(NJError::AlgorithmFailure),
573                SubstitutionModel::PDiff => run_nj::<Protein, PDiff>(
574                    msa,
575                    conf.n_bootstrap_samples,
576                    cb,
577                    num_threads,
578                    include_distance_matrix,
579                    include_average_distance,
580                )
581                .map_err(NJError::AlgorithmFailure),
582                SubstitutionModel::JukesCantor | SubstitutionModel::Kimura2P => {
583                    Err(NJError::IncompatibleModel {
584                        model,
585                        alphabet: Alphabet::Protein,
586                    })
587                }
588            }
589        }
590    }
591}
592
593/// Computes pairwise distances from an aligned MSA and returns a [`DistanceResult`].
594///
595/// The alphabet is auto-detected from the sequences unless `conf.alphabet` is set;
596/// `conf.substitution_model` must be compatible with the alphabet (see the
597/// module-level compatibility table). Returns `Err` for an empty MSA, incompatible
598/// model, or mismatched sequence lengths. Does not run Neighbor-Joining or bootstrapping.
599pub fn distance_matrix(conf: DistConfig) -> Result<DistanceResult, NJError> {
600    let num_threads = conf.num_threads;
601    validate_msa(&conf.msa)?;
602    let alphabet = conf.alphabet.unwrap_or_else(|| detect_alphabet(&conf.msa));
603    let model = conf.substitution_model;
604    match alphabet {
605        Alphabet::DNA => {
606            let msa =
607                MSA::<DNA>::from_iter(conf.msa.into_iter().map(|s| (s.identifier, s.sequence)));
608            match model {
609                SubstitutionModel::PDiff => run_distance_matrix::<DNA, PDiff>(msa, num_threads)
610                    .map_err(NJError::AlgorithmFailure),
611                SubstitutionModel::JukesCantor => {
612                    run_distance_matrix::<DNA, JukesCantor>(msa, num_threads)
613                        .map_err(NJError::AlgorithmFailure)
614                }
615                SubstitutionModel::Kimura2P => {
616                    run_distance_matrix::<DNA, Kimura2P>(msa, num_threads)
617                        .map_err(NJError::AlgorithmFailure)
618                }
619                SubstitutionModel::Poisson => Err(NJError::IncompatibleModel {
620                    model,
621                    alphabet: Alphabet::DNA,
622                }),
623            }
624        }
625        Alphabet::Protein => {
626            let msa =
627                MSA::<Protein>::from_iter(conf.msa.into_iter().map(|s| (s.identifier, s.sequence)));
628            match model {
629                SubstitutionModel::Poisson => {
630                    run_distance_matrix::<Protein, Poisson>(msa, num_threads)
631                        .map_err(NJError::AlgorithmFailure)
632                }
633                SubstitutionModel::PDiff => run_distance_matrix::<Protein, PDiff>(msa, num_threads)
634                    .map_err(NJError::AlgorithmFailure),
635                SubstitutionModel::JukesCantor | SubstitutionModel::Kimura2P => {
636                    Err(NJError::IncompatibleModel {
637                        model,
638                        alphabet: Alphabet::Protein,
639                    })
640                }
641            }
642        }
643    }
644}
645
646/// Computes the mean of all `n*(n-1)/2` unique pairwise distances.
647///
648/// Same alphabet auto-detection and model–alphabet compatibility as [`nj`].
649/// Returns `0.0` for fewer than 2 taxa. Returns `Err` for an empty MSA,
650/// incompatible model, or mismatched sequence lengths.
651pub fn average_distance(conf: DistConfig) -> Result<f64, NJError> {
652    let num_threads = conf.num_threads;
653    validate_msa(&conf.msa)?;
654    let alphabet = conf.alphabet.unwrap_or_else(|| detect_alphabet(&conf.msa));
655    let model = conf.substitution_model;
656    match alphabet {
657        Alphabet::DNA => {
658            let msa =
659                MSA::<DNA>::from_iter(conf.msa.into_iter().map(|s| (s.identifier, s.sequence)));
660            match model {
661                SubstitutionModel::PDiff => run_average_distance::<DNA, PDiff>(msa, num_threads)
662                    .map_err(NJError::AlgorithmFailure),
663                SubstitutionModel::JukesCantor => {
664                    run_average_distance::<DNA, JukesCantor>(msa, num_threads)
665                        .map_err(NJError::AlgorithmFailure)
666                }
667                SubstitutionModel::Kimura2P => {
668                    run_average_distance::<DNA, Kimura2P>(msa, num_threads)
669                        .map_err(NJError::AlgorithmFailure)
670                }
671                SubstitutionModel::Poisson => Err(NJError::IncompatibleModel {
672                    model,
673                    alphabet: Alphabet::DNA,
674                }),
675            }
676        }
677        Alphabet::Protein => {
678            let msa =
679                MSA::<Protein>::from_iter(conf.msa.into_iter().map(|s| (s.identifier, s.sequence)));
680            match model {
681                SubstitutionModel::Poisson => {
682                    run_average_distance::<Protein, Poisson>(msa, num_threads)
683                        .map_err(NJError::AlgorithmFailure)
684                }
685                SubstitutionModel::PDiff => {
686                    run_average_distance::<Protein, PDiff>(msa, num_threads)
687                        .map_err(NJError::AlgorithmFailure)
688                }
689                SubstitutionModel::JukesCantor | SubstitutionModel::Kimura2P => {
690                    Err(NJError::IncompatibleModel {
691                        model,
692                        alphabet: Alphabet::Protein,
693                    })
694                }
695            }
696        }
697    }
698}
699
700#[cfg(test)]
701mod tests {
702    use super::*;
703    use crate::config::DistConfig;
704    use crate::models::SubstitutionModel;
705
706    /// Builds a minimal [`NJConfig`] with both include flags set to the given values.
707    fn nj_conf(pairs: &[(&str, &str)], include_dm: bool, include_avg: bool) -> NJConfig {
708        NJConfig {
709            msa: pairs
710                .iter()
711                .map(|(id, seq)| SequenceObject {
712                    identifier: id.to_string(),
713                    sequence: seq.to_string(),
714                })
715                .collect(),
716            n_bootstrap_samples: 0,
717            substitution_model: SubstitutionModel::PDiff,
718            alphabet: None,
719            num_threads: None,
720            return_distance_matrix: include_dm,
721            return_average_distance: include_avg,
722        }
723    }
724
725    #[test]
726    fn test_nj_wrapper_simple_tree() {
727        // ACGTCG vs ACG-GC: pos 3 is gapped (excluded), 2 diffs out of 5 comparable
728        // → distance = 0.4; two taxa → each branch = 0.2.
729        let sequences = vec![
730            SequenceObject {
731                identifier: "A".into(),
732                sequence: "ACGTCG".into(),
733            },
734            SequenceObject {
735                identifier: "B".into(),
736                sequence: "ACG-GC".into(),
737            },
738        ];
739        let conf = NJConfig {
740            msa: sequences,
741            n_bootstrap_samples: 0,
742            substitution_model: SubstitutionModel::PDiff,
743            alphabet: None,
744            num_threads: None,
745            return_distance_matrix: false,
746            return_average_distance: false,
747        };
748        let result = nj(conf, None).expect("NJ failed");
749        assert_eq!(result.newick, "(A:0.200,B:0.200);");
750    }
751
752    #[test]
753    fn test_nj_wrapper_adds_semicolon() {
754        let sequences = vec![
755            SequenceObject {
756                identifier: "Seq0".into(),
757                sequence: "A".into(),
758            },
759            SequenceObject {
760                identifier: "Seq1".into(),
761                sequence: "A".into(),
762            },
763        ];
764        let conf = NJConfig {
765            msa: sequences,
766            n_bootstrap_samples: 0,
767            substitution_model: SubstitutionModel::PDiff,
768            alphabet: None,
769            num_threads: None,
770            return_distance_matrix: false,
771            return_average_distance: false,
772        };
773        let out = nj(conf, None).unwrap();
774        assert!(out.newick.ends_with(';'));
775    }
776
777    #[test]
778    fn test_nj_deterministic_order() {
779        let sequences = vec![
780            SequenceObject {
781                identifier: "Seq0".into(),
782                sequence: "ACGTCG".into(),
783            },
784            SequenceObject {
785                identifier: "Seq1".into(),
786                sequence: "ACG-GC".into(),
787            },
788            SequenceObject {
789                identifier: "Seq2".into(),
790                sequence: "ACGCGT".into(),
791            },
792        ];
793        let conf = NJConfig {
794            msa: sequences,
795            n_bootstrap_samples: 0,
796            substitution_model: SubstitutionModel::PDiff,
797            alphabet: None,
798            num_threads: None,
799            return_distance_matrix: false,
800            return_average_distance: false,
801        };
802
803        let t1 = nj(conf.clone(), None).unwrap();
804        let t2 = nj(conf, None).unwrap();
805        assert_eq!(t1, t2);
806    }
807
808    #[test]
809    fn test_nj_wrapper_empty_msa() {
810        let conf = NJConfig {
811            msa: vec![],
812            n_bootstrap_samples: 0,
813            substitution_model: SubstitutionModel::PDiff,
814            alphabet: None,
815            num_threads: None,
816            return_distance_matrix: false,
817            return_average_distance: false,
818        };
819        let result = nj(conf, None);
820        assert!(result.is_err());
821    }
822
823    #[test]
824    fn test_nj_wrapper_incorrect_model_for_alphabet() {
825        let sequences = vec![
826            SequenceObject {
827                identifier: "Seq0".into(),
828                sequence: "ACGTCG".into(),
829            },
830            SequenceObject {
831                identifier: "Seq1".into(),
832                sequence: "ACG-GC".into(),
833            },
834        ];
835        let conf = NJConfig {
836            msa: sequences,
837            n_bootstrap_samples: 0,
838            substitution_model: SubstitutionModel::Poisson, // protein model for DNA MSA
839            alphabet: None,
840            num_threads: None,
841            return_distance_matrix: false,
842            return_average_distance: false,
843        };
844        let result = nj(conf, None);
845        assert!(result.is_err());
846    }
847
848    #[test]
849    fn test_nj_wrapper_incorrect_model_for_protein() {
850        let sequences = vec![
851            SequenceObject {
852                identifier: "Seq0".into(),
853                sequence: "ACDEFGH".into(),
854            },
855            SequenceObject {
856                identifier: "Seq1".into(),
857                sequence: "ACD-FGH".into(),
858            },
859        ];
860        let conf = NJConfig {
861            msa: sequences,
862            n_bootstrap_samples: 0,
863            substitution_model: SubstitutionModel::JukesCantor, // DNA model for protein MSA
864            alphabet: None,
865            num_threads: None,
866            return_distance_matrix: false,
867            return_average_distance: false,
868        };
869        let result = nj(conf, None);
870        assert!(result.is_err());
871    }
872
873    // --- NJResult optional fields ---
874
875    #[test]
876    fn test_nj_result_no_extras_by_default() {
877        let result = nj(nj_conf(&[("A", "ACGT"), ("B", "ACGA")], false, false), None).unwrap();
878        assert!(result.distance_matrix.is_none());
879        assert!(result.average_distance.is_none());
880    }
881
882    #[test]
883    fn test_nj_result_include_distance_matrix() {
884        let result = nj(nj_conf(&[("A", "ACGT"), ("B", "ACGA")], true, false), None).unwrap();
885        let dm = result
886            .distance_matrix
887            .expect("distance_matrix should be present");
888        assert_eq!(dm.names, vec!["A", "B"]);
889        assert_eq!(dm.matrix.len(), 2);
890        assert_eq!(dm.matrix[0][0], 0.0);
891        assert!((dm.matrix[0][1] - 0.25).abs() < 1e-12);
892        assert!(result.average_distance.is_none());
893    }
894
895    #[test]
896    fn test_nj_result_include_average_distance() {
897        let result = nj(nj_conf(&[("A", "ACGT"), ("B", "ACGA")], false, true), None).unwrap();
898        let avg = result
899            .average_distance
900            .expect("average_distance should be present");
901        assert!((avg - 0.25).abs() < 1e-12);
902        assert!(result.distance_matrix.is_none());
903    }
904
905    #[test]
906    fn test_nj_result_include_both() {
907        let result = nj(nj_conf(&[("A", "ACGT"), ("B", "ACGA")], true, true), None).unwrap();
908        let dm = result
909            .distance_matrix
910            .as_ref()
911            .expect("distance_matrix should be present");
912        let avg = result
913            .average_distance
914            .expect("average_distance should be present");
915        // distance matrix and average_distance must be consistent
916        assert!((dm.matrix[0][1] - avg).abs() < 1e-12);
917    }
918
919    #[test]
920    fn test_nj_result_distance_matrix_consistent_with_newick() {
921        // Known tree: two taxa with 1/4 difference → branch lengths 0.125 each from NJ
922        let result = nj(nj_conf(&[("A", "ACGT"), ("B", "ACGA")], true, false), None).unwrap();
923        let dm = result.distance_matrix.unwrap();
924        // full matrix should be symmetric with zero diagonal
925        assert_eq!(dm.matrix[0][0], 0.0);
926        assert_eq!(dm.matrix[1][1], 0.0);
927        assert!((dm.matrix[0][1] - dm.matrix[1][0]).abs() < 1e-12);
928    }
929
930    // --- distance_matrix ---
931
932    fn dist_conf(pairs: &[(&str, &str)], model: SubstitutionModel) -> DistConfig {
933        DistConfig {
934            msa: pairs
935                .iter()
936                .map(|(id, seq)| SequenceObject {
937                    identifier: id.to_string(),
938                    sequence: seq.to_string(),
939                })
940                .collect(),
941            substitution_model: model,
942            alphabet: None,
943            num_threads: None,
944        }
945    }
946
947    #[test]
948    fn test_distance_matrix_names_and_shape() {
949        let conf = dist_conf(&[("A", "ACGT"), ("B", "ACGA")], SubstitutionModel::PDiff);
950        let result = distance_matrix(conf).unwrap();
951        assert_eq!(result.names, vec!["A", "B"]);
952        assert_eq!(result.matrix.len(), 2);
953        assert_eq!(result.matrix[0].len(), 2);
954        assert_eq!(result.matrix[1].len(), 2);
955    }
956
957    #[test]
958    fn test_distance_matrix_diagonal_zero() {
959        let conf = dist_conf(
960            &[("A", "ACGT"), ("B", "ACGA"), ("C", "AGGT")],
961            SubstitutionModel::PDiff,
962        );
963        let result = distance_matrix(conf).unwrap();
964        for i in 0..3 {
965            assert_eq!(result.matrix[i][i], 0.0);
966        }
967    }
968
969    #[test]
970    fn test_distance_matrix_symmetric() {
971        let conf = dist_conf(
972            &[("A", "ACGT"), ("B", "ACGA"), ("C", "AGGT")],
973            SubstitutionModel::PDiff,
974        );
975        let result = distance_matrix(conf).unwrap();
976        for i in 0..3 {
977            for j in 0..3 {
978                assert_eq!(result.matrix[i][j], result.matrix[j][i]);
979            }
980        }
981    }
982
983    #[test]
984    fn test_distance_matrix_pdiff_known_value() {
985        // one difference at position 3 (T vs A) out of 4 → 0.25
986        let conf = dist_conf(&[("A", "ACGT"), ("B", "ACGA")], SubstitutionModel::PDiff);
987        let result = distance_matrix(conf).unwrap();
988        assert!((result.matrix[0][1] - 0.25).abs() < 1e-12);
989        assert!((result.matrix[1][0] - 0.25).abs() < 1e-12);
990    }
991
992    #[test]
993    fn test_distance_matrix_identical_sequences_zero() {
994        let conf = dist_conf(&[("A", "ACGT"), ("B", "ACGT")], SubstitutionModel::PDiff);
995        let result = distance_matrix(conf).unwrap();
996        assert_eq!(result.matrix[0][1], 0.0);
997    }
998
999    #[test]
1000    fn test_distance_matrix_jukes_cantor_dna() {
1001        let conf = dist_conf(
1002            &[("A", "ACGT"), ("B", "ACGA"), ("C", "AGGT")],
1003            SubstitutionModel::JukesCantor,
1004        );
1005        let result = distance_matrix(conf).unwrap();
1006        // JC distance for p=0.25: -0.75 * ln(1 - 4/3 * 0.25)
1007        let expected = -0.75_f64 * (1.0_f64 - (4.0_f64 / 3.0) * 0.25).ln();
1008        assert!((result.matrix[0][1] - expected).abs() < 1e-10);
1009    }
1010
1011    #[test]
1012    fn test_distance_matrix_kimura2p_dna() {
1013        let conf = dist_conf(
1014            &[("A", "ACGT"), ("B", "ACGA"), ("C", "AGGT")],
1015            SubstitutionModel::Kimura2P,
1016        );
1017        let result = distance_matrix(conf).unwrap();
1018        assert_eq!(result.names, vec!["A", "B", "C"]);
1019        assert!(result.matrix[0][0] == 0.0);
1020    }
1021
1022    #[test]
1023    fn test_distance_matrix_poisson_protein() {
1024        let conf = dist_conf(
1025            &[("A", "ACDEFGH"), ("B", "ACDEFGK")],
1026            SubstitutionModel::Poisson,
1027        );
1028        let result = distance_matrix(conf).unwrap();
1029        // 1 diff (H vs K) out of 7: p=1/7, d=-ln(1-1/7)
1030        let expected = -(1.0_f64 - 1.0 / 7.0).ln();
1031        assert!((result.matrix[0][1] - expected).abs() < 1e-10);
1032    }
1033
1034    #[test]
1035    fn test_distance_matrix_pdiff_protein() {
1036        let conf = dist_conf(
1037            &[("A", "ACDEFGH"), ("B", "ACDEFGK")],
1038            SubstitutionModel::PDiff,
1039        );
1040        let result = distance_matrix(conf).unwrap();
1041        assert!((result.matrix[0][1] - 1.0 / 7.0).abs() < 1e-12);
1042    }
1043
1044    #[test]
1045    fn test_distance_matrix_empty_msa_errors() {
1046        let conf = DistConfig {
1047            msa: vec![],
1048            substitution_model: SubstitutionModel::PDiff,
1049            alphabet: None,
1050            num_threads: None,
1051        };
1052        assert!(distance_matrix(conf).is_err());
1053    }
1054
1055    #[test]
1056    fn test_distance_matrix_incompatible_model_errors() {
1057        // Poisson on DNA
1058        let conf = dist_conf(&[("A", "ACGT"), ("B", "ACGA")], SubstitutionModel::Poisson);
1059        assert!(distance_matrix(conf).is_err());
1060        // JukesCantor on Protein
1061        let conf = dist_conf(
1062            &[("A", "ACDEFGH"), ("B", "ACDEFGK")],
1063            SubstitutionModel::JukesCantor,
1064        );
1065        assert!(distance_matrix(conf).is_err());
1066    }
1067
1068    // --- average_distance ---
1069
1070    #[test]
1071    fn test_average_distance_identical_sequences_zero() {
1072        let conf = dist_conf(&[("A", "ACGT"), ("B", "ACGT")], SubstitutionModel::PDiff);
1073        let avg = average_distance(conf).unwrap();
1074        assert_eq!(avg, 0.0);
1075    }
1076
1077    #[test]
1078    fn test_average_distance_two_taxa_equals_pairwise() {
1079        // one difference out of 4 → 0.25
1080        let conf = dist_conf(&[("A", "ACGT"), ("B", "ACGA")], SubstitutionModel::PDiff);
1081        let avg = average_distance(conf).unwrap();
1082        assert!((avg - 0.25).abs() < 1e-12);
1083    }
1084
1085    #[test]
1086    fn test_average_distance_three_taxa_known_value() {
1087        // A↔B: 1/4=0.25 (T→A), A↔C: 1/4=0.25 (C→G), B↔C: 2/4=0.5 → avg = 1/3
1088        let conf = dist_conf(
1089            &[("A", "ACGT"), ("B", "ACGA"), ("C", "AGGT")],
1090            SubstitutionModel::PDiff,
1091        );
1092        let avg = average_distance(conf).unwrap();
1093        assert!((avg - 1.0 / 3.0).abs() < 1e-12);
1094    }
1095
1096    #[test]
1097    fn test_average_distance_jukes_cantor_dna() {
1098        let conf = dist_conf(
1099            &[("A", "ACGT"), ("B", "ACGA")],
1100            SubstitutionModel::JukesCantor,
1101        );
1102        let avg = average_distance(conf).unwrap();
1103        let expected = -0.75_f64 * (1.0_f64 - (4.0_f64 / 3.0) * 0.25).ln();
1104        assert!((avg - expected).abs() < 1e-10);
1105    }
1106
1107    #[test]
1108    fn test_average_distance_empty_msa_errors() {
1109        let conf = DistConfig {
1110            msa: vec![],
1111            substitution_model: SubstitutionModel::PDiff,
1112            alphabet: None,
1113            num_threads: None,
1114        };
1115        assert!(average_distance(conf).is_err());
1116    }
1117
1118    #[test]
1119    fn test_average_distance_incompatible_model_errors() {
1120        let conf = dist_conf(&[("A", "ACGT"), ("B", "ACGA")], SubstitutionModel::Poisson);
1121        assert!(average_distance(conf).is_err());
1122    }
1123
1124    #[test]
1125    fn test_detect_alphabet_dna() {
1126        let msa = vec![
1127            SequenceObject {
1128                identifier: "Seq0".into(),
1129                sequence: "ACGTACGT".into(),
1130            },
1131            SequenceObject {
1132                identifier: "Seq1".into(),
1133                sequence: "ACG-ACGT".into(),
1134            },
1135        ];
1136        assert_eq!(detect_alphabet(&msa), Alphabet::DNA);
1137    }
1138
1139    #[test]
1140    fn test_detect_alphabet_dna_iupac() {
1141        // IUPAC ambiguity codes and RNA U should be detected as DNA, not protein.
1142        let msa = vec![SequenceObject {
1143            identifier: "Seq0".into(),
1144            sequence: "ACGTRYWSMKHBDVNU".into(),
1145        }];
1146        assert_eq!(detect_alphabet(&msa), Alphabet::DNA);
1147    }
1148
1149    #[test]
1150    fn test_detect_alphabet_protein() {
1151        let msa = vec![
1152            SequenceObject {
1153                identifier: "Seq0".into(),
1154                sequence: "ACDEFGHIK".into(),
1155            },
1156            SequenceObject {
1157                identifier: "Seq1".into(),
1158                sequence: "ACD-FGHIK".into(),
1159            },
1160        ];
1161        assert_eq!(detect_alphabet(&msa), Alphabet::Protein);
1162    }
1163}
1164
#[cfg(all(test, feature = "parallel"))]
mod parallel_tests {
    //! Bootstrap and progress-reporting tests compiled only with the
    //! `parallel` feature enabled.

    use super::*;
    use crate::models::SubstitutionModel;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Three 8-site DNA sequences; B and C each differ from A at one site.
    fn three_seq_dna() -> Vec<SequenceObject> {
        [("A", "ACGTACGT"), ("B", "ACGCACGT"), ("C", "ACGTACGC")]
            .iter()
            .map(|&(identifier, sequence)| SequenceObject {
                identifier: identifier.into(),
                sequence: sequence.into(),
            })
            .collect()
    }

    /// PDiff configuration with `n_bootstrap_samples` replicates and no extras.
    fn base_conf(msa: Vec<SequenceObject>, n_bootstrap_samples: usize) -> NJConfig {
        NJConfig {
            msa,
            n_bootstrap_samples,
            substitution_model: SubstitutionModel::PDiff,
            alphabet: None,
            num_threads: None,
            return_distance_matrix: false,
            return_average_distance: false,
        }
    }

    /// Runs `nj` on [`three_seq_dna`] with `n` bootstrap replicates, passing the
    /// `completed` field of every `BootstrapProgress` event to `on_progress`.
    fn run_with_progress(n: usize, on_progress: impl Fn(usize) + 'static) {
        let callback: Box<dyn Fn(NJEvent)> = Box::new(move |event| {
            if let NJEvent::BootstrapProgress { completed, .. } = event {
                on_progress(completed);
            }
        });
        nj(base_conf(three_seq_dna(), n), Some(callback)).unwrap();
    }

    #[test]
    fn test_parallel_bootstrap_returns_valid_newick() {
        let outcome = nj(base_conf(three_seq_dna(), 20), None).expect("parallel NJ failed");
        let tree = &outcome.newick;
        assert!(tree.ends_with(';'));
        assert!(tree.contains(':'));
    }

    #[test]
    fn test_parallel_progress_fires_exactly_n_times() {
        let n = 10_usize;
        let calls = Arc::new(AtomicUsize::new(0));
        let sink = Arc::clone(&calls);
        run_with_progress(n, move |_| {
            sink.fetch_add(1, Ordering::SeqCst);
        });
        assert_eq!(calls.load(Ordering::SeqCst), n);
    }

    #[test]
    fn test_parallel_progress_last_call_is_total() {
        let n = 8_usize;
        let latest = Arc::new(AtomicUsize::new(0));
        let sink = Arc::clone(&latest);
        run_with_progress(n, move |completed| {
            sink.store(completed, Ordering::SeqCst);
        });
        assert_eq!(latest.load(Ordering::SeqCst), n);
    }
}
1235}