// Source file: cyanea_seq/debruijn.rs

//! De Bruijn graph construction and unitig extraction.
//!
//! Build a De Bruijn graph from sequences or raw k-mers, then extract
//! unitigs (maximal non-branching paths) with average coverage.

use std::collections::{BTreeMap, BTreeSet};

use cyanea_core::{CyaneaError, Result};

/// De Bruijn graph whose nodes are (k-1)-mers and whose edges are k-mers.
///
/// Every k-mer contributes an edge from its (k-1)-prefix node to its
/// (k-1)-suffix node.
#[derive(Clone, Debug)]
pub struct DeBruijnGraph {
    /// k-mer length the graph was built with.
    k: usize,
    /// Outgoing adjacency: each (k-1)-mer node maps to the successor nodes it
    /// points to. Nodes without outgoing edges are still present as keys.
    edges: BTreeMap<Vec<u8>, Vec<Vec<u8>>>,
    /// How many times each k-mer was observed while building the graph.
    kmer_counts: BTreeMap<Vec<u8>, usize>,
}

/// Maximal non-branching path through a [`DeBruijnGraph`].
#[derive(Clone, Debug)]
pub struct Unitig {
    /// Bases spelled out by the path.
    pub sequence: Vec<u8>,
    /// Average observation count of the k-mers that make up the path.
    pub coverage: f64,
}

32impl DeBruijnGraph {
33    /// Build a De Bruijn graph from a set of input sequences.
34    ///
35    /// Extracts all k-mers from each sequence and adds them as edges.
36    /// Bases are uppercased; only A/C/G/T bases are accepted.
37    ///
38    /// # Errors
39    ///
40    /// Returns an error if `k < 2`, any sequence is shorter than `k`,
41    /// or a sequence contains non-ACGT characters.
42    pub fn from_sequences(sequences: &[&[u8]], k: usize) -> Result<Self> {
43        if k < 2 {
44            return Err(CyaneaError::InvalidInput(
45                "k must be at least 2 for De Bruijn graph construction".into(),
46            ));
47        }
48        let mut graph = Self {
49            k,
50            edges: BTreeMap::new(),
51            kmer_counts: BTreeMap::new(),
52        };
53        for seq in sequences {
54            if seq.len() < k {
55                return Err(CyaneaError::InvalidInput(format!(
56                    "sequence length {} is shorter than k={}",
57                    seq.len(),
58                    k
59                )));
60            }
61            let upper: Vec<u8> = seq.iter().map(|b| b.to_ascii_uppercase()).collect();
62            for b in &upper {
63                if !matches!(b, b'A' | b'C' | b'G' | b'T') {
64                    return Err(CyaneaError::InvalidInput(format!(
65                        "invalid base '{}' for De Bruijn graph",
66                        *b as char
67                    )));
68                }
69            }
70            for window in upper.windows(k) {
71                graph.add_kmer(window);
72            }
73        }
74        Ok(graph)
75    }
76
77    /// Build a De Bruijn graph from pre-extracted k-mers.
78    ///
79    /// All k-mers must have the same length (≥ 2).
80    ///
81    /// # Errors
82    ///
83    /// Returns an error if the k-mer slice is empty or k-mers differ in length.
84    pub fn from_kmers(kmers: &[&[u8]]) -> Result<Self> {
85        if kmers.is_empty() {
86            return Err(CyaneaError::InvalidInput(
87                "at least one k-mer is required".into(),
88            ));
89        }
90        let k = kmers[0].len();
91        if k < 2 {
92            return Err(CyaneaError::InvalidInput(
93                "k-mers must have length at least 2".into(),
94            ));
95        }
96        for kmer in kmers {
97            if kmer.len() != k {
98                return Err(CyaneaError::InvalidInput(format!(
99                    "all k-mers must have the same length; expected {} but got {}",
100                    k,
101                    kmer.len()
102                )));
103            }
104        }
105        let mut graph = Self {
106            k,
107            edges: BTreeMap::new(),
108            kmer_counts: BTreeMap::new(),
109        };
110        for kmer in kmers {
111            let upper: Vec<u8> = kmer.iter().map(|b| b.to_ascii_uppercase()).collect();
112            graph.add_kmer(&upper);
113        }
114        Ok(graph)
115    }
116
117    fn add_kmer(&mut self, kmer: &[u8]) {
118        let prefix = kmer[..self.k - 1].to_vec();
119        let suffix = kmer[1..].to_vec();
120        self.edges.entry(prefix).or_default().push(suffix.clone());
121        // Ensure suffix node exists in the graph even if it has no outgoing edges.
122        self.edges.entry(suffix).or_default();
123        *self.kmer_counts.entry(kmer.to_vec()).or_insert(0) += 1;
124    }
125
126    /// Number of distinct (k-1)-mer nodes in the graph.
127    pub fn node_count(&self) -> usize {
128        self.edges.len()
129    }
130
131    /// Number of k-mer edges in the graph (including duplicates).
132    pub fn edge_count(&self) -> usize {
133        self.kmer_counts.len()
134    }
135
136    /// Check whether a specific k-mer exists as an edge in the graph.
137    pub fn contains_kmer(&self, kmer: &[u8]) -> bool {
138        let upper: Vec<u8> = kmer.iter().map(|b| b.to_ascii_uppercase()).collect();
139        self.kmer_counts.contains_key(&upper)
140    }
141
142    /// Extract all unitigs (maximal non-branching paths).
143    ///
144    /// A unitig is a path where every internal node has exactly one incoming
145    /// and one outgoing edge. Coverage is the mean count of constituent k-mers.
146    pub fn unitigs(&self) -> Vec<Unitig> {
147        // Compute in-degree for each node.
148        let mut in_degree: BTreeMap<&Vec<u8>, usize> = BTreeMap::new();
149        for (_, successors) in &self.edges {
150            for s in successors {
151                *in_degree.entry(s).or_insert(0) += 1;
152            }
153        }
154
155        let out_degree = |node: &Vec<u8>| -> usize {
156            self.edges.get(node).map_or(0, |v| v.len())
157        };
158        let in_deg = |node: &Vec<u8>| -> usize { in_degree.get(node).copied().unwrap_or(0) };
159
160        // A node is a branching point if in-degree != 1 or out-degree != 1.
161        let is_start = |node: &Vec<u8>| -> bool {
162            in_deg(node) != 1 || out_degree(node) != 1
163        };
164
165        let mut visited: BTreeMap<(Vec<u8>, Vec<u8>), bool> = BTreeMap::new();
166        let mut unitigs = Vec::new();
167
168        // Start unitig walks from branching nodes.
169        let nodes: Vec<Vec<u8>> = self.edges.keys().cloned().collect();
170        for node in &nodes {
171            if !is_start(node) {
172                continue;
173            }
174            let successors = match self.edges.get(node) {
175                Some(s) => s.clone(),
176                None => continue,
177            };
178            for succ in &successors {
179                let edge_key = (node.clone(), succ.clone());
180                if visited.contains_key(&edge_key) {
181                    continue;
182                }
183                visited.insert(edge_key, true);
184
185                // Build the unitig sequence: start with the prefix node, then extend.
186                let mut path_nodes: Vec<Vec<u8>> = vec![node.clone(), succ.clone()];
187                let mut current = succ.clone();
188
189                // Extend forward while the path is non-branching.
190                while !is_start(&current) {
191                    let next_successors = match self.edges.get(&current) {
192                        Some(s) if s.len() == 1 => s,
193                        _ => break,
194                    };
195                    let next = &next_successors[0];
196                    let edge_key = (current.clone(), next.clone());
197                    if visited.contains_key(&edge_key) {
198                        break;
199                    }
200                    visited.insert(edge_key, true);
201                    path_nodes.push(next.clone());
202                    current = next.clone();
203                }
204
205                // Build sequence: first (k-1)-mer, then one base per subsequent node.
206                let mut sequence = path_nodes[0].clone();
207                for p in &path_nodes[1..] {
208                    sequence.push(*p.last().unwrap());
209                }
210
211                // Compute mean coverage from constituent k-mers.
212                let k = self.k;
213                let mut total_count = 0usize;
214                let mut n_kmers = 0usize;
215                for window in sequence.windows(k) {
216                    if let Some(&c) = self.kmer_counts.get(window) {
217                        total_count += c;
218                        n_kmers += 1;
219                    }
220                }
221                let coverage = if n_kmers > 0 {
222                    total_count as f64 / n_kmers as f64
223                } else {
224                    0.0
225                };
226
227                unitigs.push(Unitig { sequence, coverage });
228            }
229        }
230
231        // Handle isolated cycles (all nodes have in=1, out=1).
232        // Find any unvisited edges.
233        for (node, successors) in &self.edges {
234            for succ in successors {
235                let edge_key = (node.clone(), succ.clone());
236                if visited.contains_key(&edge_key) {
237                    continue;
238                }
239                visited.insert(edge_key.clone(), true);
240
241                let mut path_nodes: Vec<Vec<u8>> = vec![node.clone(), succ.clone()];
242                let mut current = succ.clone();
243
244                loop {
245                    let next_successors = match self.edges.get(&current) {
246                        Some(s) if s.len() == 1 => s,
247                        _ => break,
248                    };
249                    let next = &next_successors[0];
250                    let ek = (current.clone(), next.clone());
251                    if visited.contains_key(&ek) {
252                        break;
253                    }
254                    visited.insert(ek, true);
255                    path_nodes.push(next.clone());
256                    current = next.clone();
257                }
258
259                let mut sequence = path_nodes[0].clone();
260                for p in &path_nodes[1..] {
261                    sequence.push(*p.last().unwrap());
262                }
263
264                let k = self.k;
265                let mut total_count = 0usize;
266                let mut n_kmers = 0usize;
267                for window in sequence.windows(k) {
268                    if let Some(&c) = self.kmer_counts.get(window) {
269                        total_count += c;
270                        n_kmers += 1;
271                    }
272                }
273                let coverage = if n_kmers > 0 {
274                    total_count as f64 / n_kmers as f64
275                } else {
276                    0.0
277                };
278
279                unitigs.push(Unitig { sequence, coverage });
280            }
281        }
282
283        unitigs
284    }
285}
286
287#[cfg(test)]
288mod tests {
289    use super::*;
290
291    #[test]
292    fn debruijn_simple_sequence() {
293        // ACGTACGT with k=3 should build a valid graph.
294        let graph = DeBruijnGraph::from_sequences(&[b"ACGTACGT"], 3).unwrap();
295        assert!(graph.node_count() > 0);
296        assert!(graph.edge_count() > 0);
297    }
298
299    #[test]
300    fn debruijn_node_edge_counts() {
301        // ACGT with k=3: k-mers are ACG, CGT → 2 edges.
302        // Nodes (2-mers): AC, CG, GT → 3 nodes.
303        let graph = DeBruijnGraph::from_sequences(&[b"ACGT"], 3).unwrap();
304        assert_eq!(graph.edge_count(), 2);
305        assert_eq!(graph.node_count(), 3);
306    }
307
308    #[test]
309    fn debruijn_contains_kmer() {
310        let graph = DeBruijnGraph::from_sequences(&[b"ACGTACGT"], 3).unwrap();
311        assert!(graph.contains_kmer(b"ACG"));
312        assert!(graph.contains_kmer(b"CGT"));
313        assert!(!graph.contains_kmer(b"AAA"));
314    }
315
316    #[test]
317    fn unitig_extraction_simple() {
318        // Linear sequence with no branching should yield unitig(s).
319        let graph = DeBruijnGraph::from_sequences(&[b"ACGTACGT"], 3).unwrap();
320        let unitigs = graph.unitigs();
321        assert!(!unitigs.is_empty());
322        // At least one unitig should contain our original k-mers.
323        let total_len: usize = unitigs.iter().map(|u| u.sequence.len()).sum();
324        assert!(total_len >= 3); // at least one k-mer length
325    }
326
327    #[test]
328    fn unitig_coverage_correct() {
329        // Feed the same sequence twice — coverage should be 2.0.
330        let graph =
331            DeBruijnGraph::from_sequences(&[b"ACGT", b"ACGT"], 3).unwrap();
332        let unitigs = graph.unitigs();
333        assert!(!unitigs.is_empty());
334        for u in &unitigs {
335            assert!((u.coverage - 2.0).abs() < 1e-10);
336        }
337    }
338}