Skip to main content

nj/
lib.rs

1//! Neighbor-Joining phylogenetic tree inference library.
2//!
3//! # Data flow
4//!
5//! ```text
6//! [FASTA / Python dict / JS object]
7//!         │
8//!         ▼
9//!      NJConfig  (config.rs)
10//!         │
11//!         ▼
12//!   detect_alphabet()  ──►  Alphabet::DNA | Alphabet::Protein
13//!         │
14//!         ▼
15//!     MSA<DNA|Protein>  (msa.rs)
16//!      ├── bootstrap() ──► bootstrap_clade_counts()
17//!      └── into_dist::<Model>()
18//!               │
19//!               ▼
//!           DistMat  (distance_matrix.rs)
21//!               │
22//!               ▼
23//!         neighbor_joining()  ──►  NJState::run()  (nj.rs)
24//!               │
25//!               ▼
26//!           TreeNode  (tree.rs)
27//!               │
28//!               ▼
29//!           to_newick()  ──►  Newick String
30//! ```
31//!
32//! # Public API
33//!
34//! The single public entry point is [`nj`], which accepts an [`NJConfig`] and
35//! returns a Newick string. Everything else is internal implementation detail
36//! exposed only to the Python and WASM wrapper crates.
37//!
38//! # Model–alphabet compatibility
39//!
40//! | Model | DNA | Protein |
41//! |-------|-----|---------|
42//! | `PDiff` | ✓ | ✓ |
43//! | `JukesCantor` | ✓ | — |
44//! | `Kimura2P` | ✓ | — |
45//! | `Poisson` | — | ✓ |
46//!
47//! Providing an incompatible model returns an `Err` from [`nj`].
48pub mod alphabet;
49pub mod config;
50pub mod distance_matrix;
51pub mod models;
52pub mod msa;
53pub mod nj;
54pub mod tree;
55
56use bitvec::prelude::{BitVec, Lsb0, bitvec};
57use std::collections::HashMap;
58
59use crate::alphabet::{Alphabet, AlphabetEncoding, DNA, Protein};
60use crate::config::SubstitutionModel;
61pub use crate::config::{MSA, NJConfig, SequenceObject};
62use crate::distance_matrix::DistMat;
63use crate::models::{JukesCantor, Kimura2P, ModelCalculation, PDiff, Poisson};
64use crate::tree::{NameOrSupport, TreeNode};
65
66/// Fills `out` with the leaf indices of all taxa in the subtree rooted at `node`.
67///
68/// Bits in `out` are set to `true` for each leaf encountered. The bit position
69/// is looked up in `idx` by the leaf's name label. Returns `Err` if a leaf
70/// has no name label (should not occur for well-formed NJ trees).
71fn bitset_of(
72    node: &TreeNode,
73    idx: &HashMap<String, usize>,
74    out: &mut BitVec<u8, Lsb0>,
75) -> Result<(), String> {
76    match &node.children {
77        None => match &node.label {
78            Some(NameOrSupport::Name(name)) => {
79                let i = idx[name];
80                out.set(i, true);
81                Ok(())
82            }
83            _ => Err("Leaf node without a name label".into()),
84        },
85        Some([l, r]) => {
86            bitset_of(l, idx, out)?;
87            bitset_of(r, idx, out)?;
88            Ok(())
89        }
90    }
91}
92
93/// Recursively counts how many times each non-trivial clade appears in `tree`.
94///
95/// A clade is represented as a raw-byte encoding of a `BitVec` over the `n_taxa`
96/// leaf indices. Only clades with `1 < size < n_taxa` (i.e. proper internal
97/// clades) are counted. Each call increments the clade's entry in `counter` by 1.
98/// Used by [`bootstrap_clade_counts`] to aggregate over bootstrap replicates.
99fn count_clades(
100    tree: &TreeNode,
101    idx: &HashMap<String, usize>,
102    n_taxa: usize,
103    counter: &mut HashMap<Vec<u8>, usize>,
104) -> Result<(), String> {
105    if let Some([l, r]) = &tree.children {
106        // compute clade bitvec
107        let mut bv = bitvec![u8, Lsb0; 0; n_taxa];
108        bitset_of(tree, idx, &mut bv).unwrap();
109
110        let n = bv.count_ones();
111        if n > 1 && n < n_taxa {
112            // unique by structure; no HashSet needed
113            counter
114                .entry(bv.as_raw_slice().to_vec())
115                .and_modify(|c| *c += 1)
116                .or_insert(1);
117        }
118
119        // recursion
120        count_clades(l, idx, n_taxa, counter)?;
121        count_clades(r, idx, n_taxa, counter)?;
122    }
123    Ok(())
124}
125
126/// Performs bootstrap sampling and counts clades across bootstrap trees.
127fn bootstrap_clade_counts<A: AlphabetEncoding, M: ModelCalculation<A>>(
128    msa: &MSA<A>,
129    n_bootstrap_samples: usize,
130) -> Result<Option<HashMap<Vec<u8>, usize>>, String> {
131    if n_bootstrap_samples == 0 {
132        return Ok(None);
133    }
134    let idx_map: HashMap<String, usize> = msa.to_index_map();
135    let mut counter = HashMap::new();
136    for _ in 0..n_bootstrap_samples {
137        let tree = msa
138            .bootstrap()?
139            .into_dist::<M>()
140            .neighbor_joining()
141            .expect("NJ bootstrap iteration failed");
142        count_clades(&tree, &idx_map, msa.n_sequences, &mut counter)?;
143    }
144    Ok(Some(counter))
145}
146
147/// Annotates internal nodes with bootstrap support values from `counts`.
148///
149/// For each internal node, computes its clade `BitVec`, looks up the count in
150/// `counts`, and assigns a [`NameOrSupport::Support`] label if a matching
151/// entry is found. Nodes whose clade was never observed in bootstrap replicates
152/// receive no label.
153fn add_bootstrap_to_tree(
154    node: &mut TreeNode,
155    idx: &HashMap<String, usize>,
156    n_taxa: usize,
157    counts: &HashMap<Vec<u8>, usize>,
158) {
159    if node.children.is_some() {
160        // compute clade bitvec
161        let mut bv = bitvec![u8, Lsb0; 0; n_taxa];
162        bitset_of(node, idx, &mut bv).unwrap();
163
164        let n = bv.count_ones();
165        if n > 1 && n < n_taxa {
166            if let Some(c) = counts.get(&bv.as_raw_slice().to_vec()) {
167                node.label = Some(NameOrSupport::Support(*c));
168            }
169        }
170
171        // recursion
172        if let Some([l, r]) = &mut node.children {
173            add_bootstrap_to_tree(l, idx, n_taxa, counts);
174            add_bootstrap_to_tree(r, idx, n_taxa, counts);
175        }
176    }
177}
178
179/// Heuristically detects whether the MSA contains DNA or protein sequences.
180///
181/// Returns [`Alphabet::DNA`] unless any sequence contains a byte that is not
182/// in `{A, C, G, T, U, N, -}` (case-insensitive), in which case
183/// [`Alphabet::Protein`] is returned. This covers all 20 standard amino acids
184/// since letters like `D`, `E`, `F`, `H`, `I`, `K`, `L`, `M`, `P`, `Q`,
185/// `R`, `S`, `V`, `W`, `Y` cannot appear in a DNA alignment.
186fn detect_alphabet(msa: &[SequenceObject]) -> Result<Alphabet, String> {
187    // Simple heuristic: if any character is > A,C,G,T,N, assume protein
188    let mut is_protein = false;
189
190    for seq in msa {
191        for c in seq.sequence.bytes() {
192            match c.to_ascii_uppercase() {
193                b'A' | b'C' | b'G' | b'T' | b'U' | b'N' | b'-' => { /* still possible DNA */ }
194                _ => {
195                    is_protein = true;
196                    break;
197                }
198            }
199        }
200        if is_protein {
201            break;
202        }
203    }
204
205    Ok(if is_protein {
206        Alphabet::Protein
207    } else {
208        Alphabet::DNA
209    })
210}
211
212/// Runs NJ with model `M` on alphabet `A` and returns a Newick string.
213///
214/// If `n_bootstrap_samples > 0`, generates that many bootstrap replicates,
215/// collects clade counts via [`bootstrap_clade_counts`], runs NJ on the
216/// original distances, and annotates the tree before serialising to Newick.
217fn run_nj<A, M>(msa: MSA<A>, n_bootstrap_samples: usize) -> Result<String, String>
218where
219    A: AlphabetEncoding,
220    M: ModelCalculation<A>,
221{
222    // bootstrap_clade_counts should be generic over A,M too (not shown here)
223    let clade_counts = bootstrap_clade_counts::<A, M>(&msa, n_bootstrap_samples)?;
224
225    let mut main_tree = msa.into_dist::<M>().neighbor_joining()?;
226    let newick = match clade_counts {
227        Some(counts) => {
228            let main_idx_map: HashMap<String, usize> = msa.to_index_map();
229            add_bootstrap_to_tree(&mut main_tree, &main_idx_map, msa.n_sequences, &counts);
230            main_tree.to_newick()
231        }
232        None => main_tree.to_newick(),
233    };
234    Ok(newick)
235}
236
237/// Infers a phylogenetic tree from an aligned MSA and returns a Newick string.
238///
239/// This is the single public entry point for the library. The alphabet is
240/// auto-detected from the sequences; `conf.substitution_model` must be
241/// compatible with the detected alphabet (see the module-level compatibility
242/// table). Returns `Err` for an empty MSA, an incompatible model, or any
243/// internal NJ failure.
244pub fn nj(conf: NJConfig) -> Result<String, String> {
245    if conf.msa.is_empty() {
246        return Err("Input MSA is empty".into());
247    }
248    let alphabet = detect_alphabet(&conf.msa)?;
249    match alphabet {
250        Alphabet::DNA => {
251            // build MSA specialized to DNA (pre-encodes using DNA::encode)
252            let msa =
253                MSA::<DNA>::from_iter(conf.msa.into_iter().map(|s| (s.identifier, s.sequence)));
254
255            match conf.substitution_model {
256                SubstitutionModel::PDiff => run_nj::<DNA, PDiff>(msa, conf.n_bootstrap_samples),
257                SubstitutionModel::JukesCantor => {
258                    run_nj::<DNA, JukesCantor>(msa, conf.n_bootstrap_samples)
259                }
260                SubstitutionModel::Kimura2P => {
261                    run_nj::<DNA, Kimura2P>(msa, conf.n_bootstrap_samples)
262                }
263                // Poisson is a protein model — either disallow here or handle by error:
264                SubstitutionModel::Poisson => {
265                    Err("Poisson is a protein model; cannot use with DNA".into())
266                }
267            }
268        }
269
270        Alphabet::Protein => {
271            let msa =
272                MSA::<Protein>::from_iter(conf.msa.into_iter().map(|s| (s.identifier, s.sequence)));
273
274            match conf.substitution_model {
275                // Poisson is valid for protein:
276                SubstitutionModel::Poisson => {
277                    run_nj::<Protein, Poisson>(msa, conf.n_bootstrap_samples)
278                }
279                SubstitutionModel::PDiff => run_nj::<Protein, PDiff>(msa, conf.n_bootstrap_samples),
280                // DNA-only models should be rejected for proteins:
281                SubstitutionModel::JukesCantor | SubstitutionModel::Kimura2P => {
282                    Err("Selected model is for DNA; cannot use with Protein".into())
283                }
284            }
285        }
286    }
287}
288
#[cfg(test)]
mod tests {
    use super::*;
    // `SubstitutionModel` is defined in `config` (the top of this file imports it
    // from there, and it is the type of `NJConfig::substitution_model`).
    use crate::config::SubstitutionModel;

    /// Builds a `SequenceObject` from string literals.
    fn seq(identifier: &str, sequence: &str) -> SequenceObject {
        SequenceObject {
            identifier: identifier.into(),
            sequence: sequence.into(),
        }
    }

    #[test]
    fn test_nj_wrapper_simple_tree() {
        let conf = NJConfig {
            msa: vec![seq("A", "ACGTCG"), seq("B", "ACG-GC")],
            n_bootstrap_samples: 0,
            substitution_model: SubstitutionModel::PDiff,
        };
        let newick = nj(conf).expect("NJ failed");
        assert_eq!(newick, "(A:0.167,B:0.167);");
    }

    #[test]
    fn test_nj_wrapper_adds_semicolon() {
        let conf = NJConfig {
            msa: vec![seq("Seq0", "A"), seq("Seq1", "A")],
            n_bootstrap_samples: 0,
            substitution_model: SubstitutionModel::PDiff,
        };
        let out = nj(conf).unwrap();
        assert!(out.ends_with(';'));
    }

    #[test]
    fn test_nj_deterministic_order() {
        let conf = NJConfig {
            msa: vec![
                seq("Seq0", "ACGTCG"),
                seq("Seq1", "ACG-GC"),
                seq("Seq2", "ACGCGT"),
            ],
            n_bootstrap_samples: 0,
            substitution_model: SubstitutionModel::PDiff,
        };

        let t1 = nj(conf.clone()).unwrap();
        let t2 = nj(conf).unwrap();
        assert_eq!(t1, t2);
    }

    #[test]
    fn test_nj_wrapper_empty_msa() {
        let conf = NJConfig {
            msa: vec![],
            n_bootstrap_samples: 0,
            substitution_model: SubstitutionModel::PDiff,
        };
        let result = nj(conf);
        assert!(result.is_err());
    }

    #[test]
    fn test_nj_wrapper_incorrect_model_for_alphabet() {
        let conf = NJConfig {
            msa: vec![seq("Seq0", "ACGTCG"), seq("Seq1", "ACG-GC")],
            n_bootstrap_samples: 0,
            substitution_model: SubstitutionModel::Poisson, // protein model for DNA MSA
        };
        let result = nj(conf);
        assert!(result.is_err());
    }

    #[test]
    fn test_nj_wrapper_incorrect_model_for_protein() {
        let conf = NJConfig {
            msa: vec![seq("Seq0", "ACDEFGH"), seq("Seq1", "ACD-FGH")],
            n_bootstrap_samples: 0,
            substitution_model: SubstitutionModel::JukesCantor, // DNA model for protein MSA
        };
        let result = nj(conf);
        assert!(result.is_err());
    }

    #[test]
    fn test_detect_alphabet_dna() {
        let msa = vec![seq("Seq0", "ACGTACGT"), seq("Seq1", "ACG-ACGT")];
        let alphabet = detect_alphabet(&msa).expect("detection failed");
        assert_eq!(alphabet, Alphabet::DNA);
    }

    #[test]
    fn test_detect_alphabet_lowercase_and_rna_are_dna() {
        // Detection is case-insensitive and accepts RNA's `U`.
        let msa = vec![seq("Seq0", "acgtn"), seq("Seq1", "ACGU-")];
        let alphabet = detect_alphabet(&msa).expect("detection failed");
        assert_eq!(alphabet, Alphabet::DNA);
    }

    #[test]
    fn test_detect_alphabet_protein() {
        let msa = vec![seq("Seq0", "ACDEFGHIK"), seq("Seq1", "ACD-FGHIK")];
        let alphabet = detect_alphabet(&msa).expect("detection failed");
        assert_eq!(alphabet, Alphabet::Protein);
    }
}