Skip to main content

cyanea_seq/
taxonomy.rs

1//! Taxonomic classification — taxonomy trees, LCA queries, k-mer classifiers.
2//!
3//! Build a taxonomy tree, compute lowest common ancestors, and classify
4//! sequences using a Kraken-style k-mer approach.
5
6use std::collections::BTreeMap;
7
8use cyanea_core::{CyaneaError, Result};
9
/// Taxonomic rank in the NCBI hierarchy.
///
/// The enum derives equality and hashing but no ordering, so two ranks can
/// be tested for equality but not compared with `<` / `>`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TaxonRank {
    /// Domain rank (the test fixture uses it for "Bacteria").
    Domain,
    /// Phylum rank (the test fixture uses it for "Proteobacteria").
    Phylum,
    /// Class rank.
    Class,
    /// Order rank.
    Order,
    /// Family rank.
    Family,
    /// Genus rank.
    Genus,
    /// Species rank (the test fixture uses it for "E. coli").
    Species,
    /// No specific rank (the test fixture uses it for the root node).
    Unranked,
}
22
/// A node in the taxonomy tree.
///
/// Nodes created through `TaxonomyTree::add_node` get an `id` equal to their
/// position in the tree's node vector; the `parent` link is stored exactly as
/// supplied by the caller.
#[derive(Debug, Clone)]
pub struct TaxonomyNode {
    /// Unique node identifier (index in the tree's node vector).
    pub id: usize,
    /// Taxon name.
    pub name: String,
    /// Taxonomic rank.
    pub rank: TaxonRank,
    /// Parent node id, or `None` for the root.
    pub parent: Option<usize>,
}
35
/// A rooted taxonomy tree.
///
/// Nodes are stored in a flat vector; parent links encode the tree structure.
/// A node's id is its index in that vector, so ids are dense and start at 0.
#[derive(Debug, Clone)]
pub struct TaxonomyTree {
    // All nodes in insertion order; `nodes[i].id == i` for nodes added via `add_node`.
    nodes: Vec<TaxonomyNode>,
}
43
44impl TaxonomyTree {
45    /// Create a new empty taxonomy tree.
46    pub fn new() -> Self {
47        Self { nodes: Vec::new() }
48    }
49
50    /// Add a node to the tree. Returns the node's id.
51    ///
52    /// The root node should have `parent = None`. All other nodes must
53    /// reference a valid parent id.
54    pub fn add_node(&mut self, name: &str, rank: TaxonRank, parent: Option<usize>) -> usize {
55        let id = self.nodes.len();
56        self.nodes.push(TaxonomyNode {
57            id,
58            name: name.to_string(),
59            rank,
60            parent,
61        });
62        id
63    }
64
65    /// Compute the lowest common ancestor of a set of node ids.
66    ///
67    /// Returns the deepest node that is an ancestor of all input nodes.
68    ///
69    /// # Errors
70    ///
71    /// Returns an error if `ids` is empty or any id is out of range.
72    pub fn lca(&self, ids: &[usize]) -> Result<usize> {
73        if ids.is_empty() {
74            return Err(CyaneaError::InvalidInput(
75                "at least one taxon id is required for LCA".into(),
76            ));
77        }
78        for &id in ids {
79            if id >= self.nodes.len() {
80                return Err(CyaneaError::InvalidInput(format!(
81                    "taxon id {} is out of range (tree has {} nodes)",
82                    id,
83                    self.nodes.len()
84                )));
85            }
86        }
87
88        // Get ancestor set for the first node.
89        let mut common_ancestors = self.ancestor_set(ids[0]);
90
91        // Intersect with ancestor sets of remaining nodes.
92        for &id in &ids[1..] {
93            let ancestors = self.ancestor_set(id);
94            common_ancestors.retain(|a| ancestors.contains(a));
95        }
96
97        // The LCA is the deepest (greatest depth) common ancestor.
98        common_ancestors
99            .into_iter()
100            .max_by_key(|&a| self.depth(a))
101            .ok_or_else(|| {
102                CyaneaError::InvalidInput("no common ancestor found".into())
103            })
104    }
105
106    /// Get the full lineage (path to root) for a node, ordered from the node to root.
107    pub fn lineage(&self, id: usize) -> Vec<usize> {
108        let mut path = Vec::new();
109        let mut current = id;
110        loop {
111            if current >= self.nodes.len() {
112                break;
113            }
114            path.push(current);
115            match self.nodes[current].parent {
116                Some(p) => current = p,
117                None => break,
118            }
119        }
120        path
121    }
122
123    /// Depth of a node (root = 0).
124    pub fn depth(&self, id: usize) -> usize {
125        self.lineage(id).len().saturating_sub(1)
126    }
127
128    /// Get the set of all ancestors of a node (including itself).
129    fn ancestor_set(&self, id: usize) -> Vec<usize> {
130        self.lineage(id)
131    }
132}
133
134impl Default for TaxonomyTree {
135    fn default() -> Self {
136        Self::new()
137    }
138}
139
/// K-mer based taxonomic classifier (Kraken-style).
///
/// Maps k-mer hashes to sets of taxon ids. Classification of a query
/// sequence takes the LCA of all k-mer hits.
#[derive(Debug, Clone)]
pub struct KmerClassifier {
    /// k-mer hash → taxon ids that contain this k-mer.
    ///
    /// Keys are the values produced by `hash_kmer`; each id list is kept free
    /// of duplicates by `add_reference`.
    db: BTreeMap<u64, Vec<usize>>,
    /// The taxonomy tree for LCA computation.
    taxonomy: TaxonomyTree,
    /// k-mer length.
    k: usize,
}
153
154impl KmerClassifier {
155    /// Create a new classifier with the given taxonomy and k-mer size.
156    pub fn new(taxonomy: TaxonomyTree, k: usize) -> Self {
157        Self {
158            db: BTreeMap::new(),
159            taxonomy,
160            k,
161        }
162    }
163
164    /// Index a reference sequence under the given taxon id.
165    ///
166    /// Extracts all k-mers, hashes them, and associates them with `taxon_id`.
167    pub fn add_reference(&mut self, sequence: &[u8], taxon_id: usize) {
168        if sequence.len() < self.k {
169            return;
170        }
171        let upper: Vec<u8> = sequence.iter().map(|b| b.to_ascii_uppercase()).collect();
172        for window in upper.windows(self.k) {
173            let hash = hash_kmer(window);
174            let entry = self.db.entry(hash).or_default();
175            if !entry.contains(&taxon_id) {
176                entry.push(taxon_id);
177            }
178        }
179    }
180
181    /// Classify a query sequence by k-mer matching.
182    ///
183    /// For each k-mer in the query, looks up matching taxa. Returns the LCA
184    /// of all matched taxa, preferring deeper (more specific) assignments.
185    /// Returns `None` if no k-mers match.
186    pub fn classify(&self, sequence: &[u8]) -> Option<usize> {
187        if sequence.len() < self.k {
188            return None;
189        }
190
191        let upper: Vec<u8> = sequence.iter().map(|b| b.to_ascii_uppercase()).collect();
192        let mut hit_taxa: Vec<usize> = Vec::new();
193
194        for window in upper.windows(self.k) {
195            let hash = hash_kmer(window);
196            if let Some(taxa) = self.db.get(&hash) {
197                // For each k-mer, compute LCA of its taxa and use that.
198                if let Ok(kmer_lca) = self.taxonomy.lca(taxa) {
199                    hit_taxa.push(kmer_lca);
200                }
201            }
202        }
203
204        if hit_taxa.is_empty() {
205            return None;
206        }
207
208        // Final classification: LCA of all per-kmer assignments.
209        self.taxonomy.lca(&hit_taxa).ok()
210    }
211}
212
/// Pack a DNA k-mer into a `u64` using two bits per base.
///
/// Bases are encoded case-insensitively as `A = 0`, `C = 1`, `G = 2`,
/// `T = 3`, with the first base ending up in the most significant position.
///
/// Two limitations follow directly from the encoding:
///
/// * every byte that is not one of `ACGT`/`acgt` is encoded as 0, i.e. it is
///   indistinguishable from `A`;
/// * only 32 bases fit in 64 bits, so for longer k-mers the leading bases are
///   shifted out and do not influence the result.
fn hash_kmer(kmer: &[u8]) -> u64 {
    kmer.iter().fold(0u64, |packed, &base| {
        let code: u64 = match base {
            b'C' | b'c' => 1,
            b'G' | b'g' => 2,
            b'T' | b't' => 3,
            // `A`/`a` and every unrecognised byte share code 0.
            _ => 0,
        };
        // Shifting left by two frees the low two bits for the next base.
        (packed << 2) | code
    })
}
228
#[cfg(test)]
mod tests {
    use super::*;

    // Node ids of the fixture built by `sample_tree`.
    const PROTEOBACTERIA: usize = 2;
    const E_COLI: usize = 4;
    const SALMONELLA: usize = 5;

    // Fixture layout (id: name):
    //
    //   0: root
    //   └─ 1: Bacteria
    //      ├─ 2: Proteobacteria
    //      │  ├─ 4: E. coli
    //      │  └─ 5: Salmonella
    //      └─ 3: Firmicutes
    //         └─ 6: B. subtilis
    fn sample_tree() -> TaxonomyTree {
        let mut taxonomy = TaxonomyTree::new();
        let root = taxonomy.add_node("root", TaxonRank::Unranked, None);
        let bacteria = taxonomy.add_node("Bacteria", TaxonRank::Domain, Some(root));
        let proteo = taxonomy.add_node("Proteobacteria", TaxonRank::Phylum, Some(bacteria));
        let firmicutes = taxonomy.add_node("Firmicutes", TaxonRank::Phylum, Some(bacteria));
        taxonomy.add_node("E. coli", TaxonRank::Species, Some(proteo));
        taxonomy.add_node("Salmonella", TaxonRank::Species, Some(proteo));
        taxonomy.add_node("B. subtilis", TaxonRank::Species, Some(firmicutes));
        taxonomy
    }

    #[test]
    fn taxonomy_tree_construction() {
        let taxonomy = sample_tree();
        assert_eq!(taxonomy.nodes.len(), 7);
        // The root sits at depth 0.
        assert_eq!(taxonomy.depth(0), 0);
        // root → Bacteria → Proteobacteria → E. coli is three edges deep.
        assert_eq!(taxonomy.depth(E_COLI), 3);
    }

    #[test]
    fn lca_sibling_species() {
        // Two species under the same phylum meet at that phylum.
        let taxonomy = sample_tree();
        assert_eq!(taxonomy.lca(&[E_COLI, SALMONELLA]).unwrap(), PROTEOBACTERIA);
    }

    #[test]
    fn lca_same_node() {
        // A node is its own lowest common ancestor.
        let taxonomy = sample_tree();
        assert_eq!(taxonomy.lca(&[E_COLI, E_COLI]).unwrap(), E_COLI);
    }

    #[test]
    fn classify_exact_match() {
        let mut classifier = KmerClassifier::new(sample_tree(), 4);
        // Only E. coli is indexed, so its own sequence must classify to it.
        classifier.add_reference(b"ACGTACGTACGT", E_COLI);
        assert_eq!(classifier.classify(b"ACGTACGTACGT"), Some(E_COLI));
    }

    #[test]
    fn classify_ambiguous_lca() {
        let mut classifier = KmerClassifier::new(sample_tree(), 4);
        // The same sequence is indexed under two sibling species ...
        for &species in &[E_COLI, SALMONELLA] {
            classifier.add_reference(b"ACGTACGT", species);
        }
        // ... so the query can only be resolved to their shared phylum.
        assert_eq!(classifier.classify(b"ACGTACGT"), Some(PROTEOBACTERIA));
    }
}