1use std::collections::BTreeMap;
7
8use cyanea_core::{CyaneaError, Result};
9
/// Taxonomic rank carried by a [`TaxonomyNode`].
///
/// Variants are declared from the broadest rank (`Domain`) down to
/// `Species`, followed by `Unranked` for nodes without one of those ranks.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum TaxonRank {
    /// Domain rank.
    Domain,
    /// Phylum rank.
    Phylum,
    /// Class rank.
    Class,
    /// Order rank.
    Order,
    /// Family rank.
    Family,
    /// Genus rank.
    Genus,
    /// Species rank.
    Species,
    /// No standard rank (the unit tests use this for the root node).
    Unranked,
}
22
23#[derive(Debug, Clone)]
25pub struct TaxonomyNode {
26 pub id: usize,
28 pub name: String,
30 pub rank: TaxonRank,
32 pub parent: Option<usize>,
34}
35
36#[derive(Debug, Clone)]
40pub struct TaxonomyTree {
41 nodes: Vec<TaxonomyNode>,
42}
43
44impl TaxonomyTree {
45 pub fn new() -> Self {
47 Self { nodes: Vec::new() }
48 }
49
50 pub fn add_node(&mut self, name: &str, rank: TaxonRank, parent: Option<usize>) -> usize {
55 let id = self.nodes.len();
56 self.nodes.push(TaxonomyNode {
57 id,
58 name: name.to_string(),
59 rank,
60 parent,
61 });
62 id
63 }
64
65 pub fn lca(&self, ids: &[usize]) -> Result<usize> {
73 if ids.is_empty() {
74 return Err(CyaneaError::InvalidInput(
75 "at least one taxon id is required for LCA".into(),
76 ));
77 }
78 for &id in ids {
79 if id >= self.nodes.len() {
80 return Err(CyaneaError::InvalidInput(format!(
81 "taxon id {} is out of range (tree has {} nodes)",
82 id,
83 self.nodes.len()
84 )));
85 }
86 }
87
88 let mut common_ancestors = self.ancestor_set(ids[0]);
90
91 for &id in &ids[1..] {
93 let ancestors = self.ancestor_set(id);
94 common_ancestors.retain(|a| ancestors.contains(a));
95 }
96
97 common_ancestors
99 .into_iter()
100 .max_by_key(|&a| self.depth(a))
101 .ok_or_else(|| {
102 CyaneaError::InvalidInput("no common ancestor found".into())
103 })
104 }
105
106 pub fn lineage(&self, id: usize) -> Vec<usize> {
108 let mut path = Vec::new();
109 let mut current = id;
110 loop {
111 if current >= self.nodes.len() {
112 break;
113 }
114 path.push(current);
115 match self.nodes[current].parent {
116 Some(p) => current = p,
117 None => break,
118 }
119 }
120 path
121 }
122
123 pub fn depth(&self, id: usize) -> usize {
125 self.lineage(id).len().saturating_sub(1)
126 }
127
128 fn ancestor_set(&self, id: usize) -> Vec<usize> {
130 self.lineage(id)
131 }
132}
133
134impl Default for TaxonomyTree {
135 fn default() -> Self {
136 Self::new()
137 }
138}
139
140#[derive(Debug, Clone)]
145pub struct KmerClassifier {
146 db: BTreeMap<u64, Vec<usize>>,
148 taxonomy: TaxonomyTree,
150 k: usize,
152}
153
154impl KmerClassifier {
155 pub fn new(taxonomy: TaxonomyTree, k: usize) -> Self {
157 Self {
158 db: BTreeMap::new(),
159 taxonomy,
160 k,
161 }
162 }
163
164 pub fn add_reference(&mut self, sequence: &[u8], taxon_id: usize) {
168 if sequence.len() < self.k {
169 return;
170 }
171 let upper: Vec<u8> = sequence.iter().map(|b| b.to_ascii_uppercase()).collect();
172 for window in upper.windows(self.k) {
173 let hash = hash_kmer(window);
174 let entry = self.db.entry(hash).or_default();
175 if !entry.contains(&taxon_id) {
176 entry.push(taxon_id);
177 }
178 }
179 }
180
181 pub fn classify(&self, sequence: &[u8]) -> Option<usize> {
187 if sequence.len() < self.k {
188 return None;
189 }
190
191 let upper: Vec<u8> = sequence.iter().map(|b| b.to_ascii_uppercase()).collect();
192 let mut hit_taxa: Vec<usize> = Vec::new();
193
194 for window in upper.windows(self.k) {
195 let hash = hash_kmer(window);
196 if let Some(taxa) = self.db.get(&hash) {
197 if let Ok(kmer_lca) = self.taxonomy.lca(taxa) {
199 hit_taxa.push(kmer_lca);
200 }
201 }
202 }
203
204 if hit_taxa.is_empty() {
205 return None;
206 }
207
208 self.taxonomy.lca(&hit_taxa).ok()
210 }
211}
212
/// Encodes a k-mer as a base-4 number, two bits per base.
///
/// `A`, `C`, `G` and `T` (in either case) encode as 0, 1, 2 and 3. Any other
/// byte also encodes as 0 and is therefore indistinguishable from `A`. The
/// accumulator wraps, so for inputs longer than 32 bases only the trailing
/// 32 bases influence the result.
fn hash_kmer(kmer: &[u8]) -> u64 {
    kmer.iter().fold(0u64, |acc, &base| {
        let code: u64 = match base.to_ascii_uppercase() {
            b'C' => 1,
            b'G' => 2,
            b'T' => 3,
            _ => 0,
        };
        // Shifting by two bits is the wrapping multiply-by-four.
        (acc << 2).wrapping_add(code)
    })
}
228
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the fixture taxonomy (ids in parentheses):
    ///
    /// root(0) -> Bacteria(1) -> Proteobacteria(2) -> E. coli(4), Salmonella(5)
    ///                        -> Firmicutes(3)     -> B. subtilis(6)
    fn sample_tree() -> TaxonomyTree {
        let entries = [
            ("root", TaxonRank::Unranked, None),
            ("Bacteria", TaxonRank::Domain, Some(0)),
            ("Proteobacteria", TaxonRank::Phylum, Some(1)),
            ("Firmicutes", TaxonRank::Phylum, Some(1)),
            ("E. coli", TaxonRank::Species, Some(2)),
            ("Salmonella", TaxonRank::Species, Some(2)),
            ("B. subtilis", TaxonRank::Species, Some(3)),
        ];
        let mut taxonomy = TaxonomyTree::new();
        for &(name, rank, parent) in &entries {
            taxonomy.add_node(name, rank, parent);
        }
        taxonomy
    }

    #[test]
    fn taxonomy_tree_construction() {
        let taxonomy = sample_tree();
        assert_eq!(taxonomy.nodes.len(), 7);
        // The root sits at depth 0; E. coli is three links below it.
        assert_eq!(taxonomy.depth(0), 0);
        assert_eq!(taxonomy.depth(4), 3);
    }

    #[test]
    fn lca_sibling_species() {
        let taxonomy = sample_tree();
        // E. coli and Salmonella meet at Proteobacteria.
        let ancestor = taxonomy.lca(&[4, 5]).unwrap();
        assert_eq!(ancestor, 2);
    }

    #[test]
    fn lca_same_node() {
        let taxonomy = sample_tree();
        let ancestor = taxonomy.lca(&[4, 4]).unwrap();
        assert_eq!(ancestor, 4);
    }

    #[test]
    fn classify_exact_match() {
        let mut classifier = KmerClassifier::new(sample_tree(), 4);
        classifier.add_reference(b"ACGTACGTACGT", 4);
        let assigned = classifier.classify(b"ACGTACGTACGT");
        assert_eq!(assigned, Some(4));
    }

    #[test]
    fn classify_ambiguous_lca() {
        let mut classifier = KmerClassifier::new(sample_tree(), 4);
        // The same reference under two sibling species resolves to their parent.
        classifier.add_reference(b"ACGTACGT", 4);
        classifier.add_reference(b"ACGTACGT", 5);
        let assigned = classifier.classify(b"ACGTACGT");
        assert_eq!(assigned, Some(2));
    }
}