Skip to main content

argenus/
extender.rs

1
2use anyhow::Result;
3use rayon::prelude::*;
4use rustc_hash::FxHashMap;
5use std::fs::File;
6use std::io::{BufWriter, Write};
7use std::path::Path;
8use std::sync::Mutex;
9
10use crate::seqio::{FastaRecord, FastqFile};
11
/// Tuning parameters for [`ContigExtender`].
#[derive(Debug, Clone)]
pub struct ExtenderConfig {
    /// Length `k` of the k-mers used to anchor reads to contig ends.
    pub kmer_size: usize,
    /// How many k-mers, counted inward from each contig end, serve as anchors.
    pub num_edge_kmers: usize,
    /// Minimum number of candidate extensions an end needs, and the minimum
    /// number of supporting bases at each consensus position.
    pub min_coverage: usize,
    /// When the second most frequent base at a consensus position reaches this
    /// fraction of the calls, the consensus emits `N` there.
    pub branching_threshold: f64,
    /// Largest fraction of `N` a consensus may contain and still be accepted.
    pub max_n_ratio: f64,
    /// Maximum number of bases added to one contig end in a single round.
    pub extension_step: usize,
    /// Number of consecutive failed rounds after which an end stops being extended.
    pub max_consecutive_failures: usize,
}

impl Default for ExtenderConfig {
    /// Stock parameters: 21-mers, 5 anchor k-mers per end, coverage of 2,
    /// N when a minor base reaches 20%, at most 5% N per accepted consensus,
    /// up to 200 bases per round, and 2 consecutive failures per end.
    fn default() -> Self {
        Self {
            kmer_size: 21,
            num_edge_kmers: 5,
            min_coverage: 2,
            branching_threshold: 0.2,
            max_n_ratio: 0.05,
            extension_step: 200,
            max_consecutive_failures: 2,
        }
    }
}
43
/// Result of extending one input contig.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ExtendedContig {
    /// Name of the input contig this result came from.
    pub name: String,
    /// The contig sequence after all accepted left and right extensions.
    pub extended_seq: String,
}
51
52pub struct ContigExtender {
53    config: ExtenderConfig,
54    reads: Vec<String>,
55}
56
57impl ContigExtender {
58
59    pub fn new(config: ExtenderConfig) -> Self {
60        Self {
61            config,
62            reads: Vec::new(),
63        }
64    }
65
66    pub fn load_reads(&mut self, r1_path: &Path, r2_path: &Path) -> Result<()> {
67        eprintln!("Loading reads into memory...");
68
69        let r1_owned = r1_path.to_path_buf();
70        let r2_owned = r2_path.to_path_buf();
71
72        let handle_r1 = std::thread::spawn(move || -> Result<Vec<String>> {
73            let mut reads = Vec::new();
74            let mut reader = FastqFile::open(&r1_owned)?;
75            while let Some(record) = reader.read_next()? {
76                reads.push(record.seq);
77            }
78            Ok(reads)
79        });
80
81        let handle_r2 = std::thread::spawn(move || -> Result<Vec<String>> {
82            let mut reads = Vec::new();
83            let mut reader = FastqFile::open(&r2_owned)?;
84            while let Some(record) = reader.read_next()? {
85                reads.push(record.seq);
86            }
87            Ok(reads)
88        });
89
90        let reads_r1 = handle_r1.join().map_err(|_| anyhow::anyhow!("R1 load thread panicked"))??;
91        let reads_r2 = handle_r2.join().map_err(|_| anyhow::anyhow!("R2 load thread panicked"))??;
92
93        self.reads = reads_r1;
94        self.reads.extend(reads_r2);
95
96        eprintln!("Loaded {} reads into memory", self.reads.len());
97        Ok(())
98    }
99
100    pub fn extend_contigs(&self, contigs: &[FastaRecord]) -> Result<Vec<ExtendedContig>> {
101        let k = self.config.kmer_size;
102        let max_failures = self.config.max_consecutive_failures;
103
104        let states: Vec<Mutex<ContigState>> = contigs.iter().map(|c| {
105            Mutex::new(ContigState {
106                name: c.name.clone(),
107                current_seq: c.seq.clone(),
108                left_failures: 0,
109                right_failures: 0,
110            })
111        }).collect();
112
113        loop {
114
115            let active_indices: Vec<usize> = states.iter().enumerate()
116                .filter(|(_, s)| {
117                    let s = s.lock().unwrap();
118                    s.left_failures < max_failures || s.right_failures < max_failures
119                })
120                .map(|(i, _)| i)
121                .collect();
122
123            if active_indices.is_empty() {
124                break;
125            }
126
127            let mut edge_kmers: FxHashMap<u64, Vec<(usize, bool, usize)>> = FxHashMap::default();
128
129            for &idx in &active_indices {
130                let state = states[idx].lock().unwrap();
131                let seq = &state.current_seq;
132                if seq.len() < k {
133                    continue;
134                }
135
136                if state.left_failures < max_failures {
137                    for offset in 0..self.config.num_edge_kmers.min(seq.len() - k + 1) {
138                        if let Some(hash) = compute_kmer_hash(&seq[offset..offset+k]) {
139                            edge_kmers.entry(hash).or_default().push((idx, true, offset));
140                        }
141                    }
142                }
143
144                if state.right_failures < max_failures {
145                    let seq_len = seq.len();
146                    for offset in 0..self.config.num_edge_kmers.min(seq.len() - k + 1) {
147                        let start = seq_len - k - offset;
148                        if let Some(hash) = compute_kmer_hash(&seq[start..start+k]) {
149                            edge_kmers.entry(hash).or_default().push((idx, false, offset));
150                        }
151                    }
152                }
153            }
154
155            let left_candidates: Mutex<FxHashMap<usize, Vec<String>>> = Mutex::new(FxHashMap::default());
156            let right_candidates: Mutex<FxHashMap<usize, Vec<String>>> = Mutex::new(FxHashMap::default());
157
158            self.reads.par_iter().for_each(|read_seq| {
159                if read_seq.len() < k {
160                    return;
161                }
162
163                let mut local_left: FxHashMap<usize, Vec<String>> = FxHashMap::default();
164                let mut local_right: FxHashMap<usize, Vec<String>> = FxHashMap::default();
165
166                for i in 0..=(read_seq.len() - k) {
167                    let kmer_seq = &read_seq[i..i+k];
168                    if let Some(hash) = compute_kmer_hash(kmer_seq) {
169                        if let Some(matches) = edge_kmers.get(&hash) {
170                            for &(contig_idx, is_left, edge_offset) in matches {
171                                let state = states[contig_idx].lock().unwrap();
172                                let contig_kmer = if is_left {
173                                    &state.current_seq[edge_offset..edge_offset+k]
174                                } else {
175                                    let clen = state.current_seq.len();
176                                    &state.current_seq[clen-k-edge_offset..clen-edge_offset]
177                                };
178
179                                let (is_forward, is_revcomp) = check_kmer_match(kmer_seq, contig_kmer);
180                                drop(state);
181
182                                if is_left {
183                                    if is_forward && i > edge_offset {
184                                        let prefix = &read_seq[..i - edge_offset];
185                                        if !prefix.is_empty() {
186                                            let ext: String = prefix.chars().rev().collect();
187                                            local_left.entry(contig_idx).or_default().push(ext);
188                                        }
189                                    } else if is_revcomp && i + k + edge_offset < read_seq.len() {
190                                        let suffix = &read_seq[i+k+edge_offset..];
191                                        if !suffix.is_empty() {
192                                            let ext = reverse_complement(suffix);
193                                            local_left.entry(contig_idx).or_default().push(ext);
194                                        }
195                                    }
196                                } else if is_forward && i + k + edge_offset < read_seq.len() {
197                                    let suffix = &read_seq[i+k+edge_offset..];
198                                    if !suffix.is_empty() {
199                                        local_right.entry(contig_idx).or_default().push(suffix.to_string());
200                                    }
201                                } else if is_revcomp && i > edge_offset {
202                                    let prefix = &read_seq[..i - edge_offset];
203                                    if !prefix.is_empty() {
204                                        let ext = reverse_complement(prefix);
205                                        local_right.entry(contig_idx).or_default().push(ext);
206                                    }
207                                }
208                            }
209                        }
210                    }
211                }
212
213                if !local_left.is_empty() {
214                    let mut global = left_candidates.lock().unwrap();
215                    for (idx, candidates) in local_left {
216                        global.entry(idx).or_default().extend(candidates);
217                    }
218                }
219                if !local_right.is_empty() {
220                    let mut global = right_candidates.lock().unwrap();
221                    for (idx, candidates) in local_right {
222                        global.entry(idx).or_default().extend(candidates);
223                    }
224                }
225            });
226
227            let left_candidates = left_candidates.into_inner().unwrap();
228            let right_candidates = right_candidates.into_inner().unwrap();
229
230            let any_extended = std::sync::atomic::AtomicBool::new(false);
231
232            active_indices.par_iter().for_each(|&idx| {
233                let mut state = states[idx].lock().unwrap();
234
235                if state.left_failures < max_failures {
236                    if let Some(candidates) = left_candidates.get(&idx) {
237                        if candidates.len() >= self.config.min_coverage {
238                            let consensus = build_consensus_sequence(
239                                candidates,
240                                self.config.min_coverage,
241                                self.config.branching_threshold,
242                                self.config.extension_step,
243                            );
244                            if !consensus.is_empty() {
245                                let n_count = consensus.chars().filter(|&c| c == 'N').count();
246                                let n_ratio = n_count as f64 / consensus.len() as f64;
247
248                                if n_ratio <= self.config.max_n_ratio {
249                                    state.current_seq = format!("{}{}", consensus, state.current_seq);
250                                    state.left_failures = 0;
251                                    any_extended.store(true, std::sync::atomic::Ordering::Relaxed);
252                                } else {
253                                    state.left_failures += 1;
254                                }
255                            } else {
256                                state.left_failures += 1;
257                            }
258                        } else {
259                            state.left_failures += 1;
260                        }
261                    } else {
262                        state.left_failures += 1;
263                    }
264                }
265
266                if state.right_failures < max_failures {
267                    if let Some(candidates) = right_candidates.get(&idx) {
268                        if candidates.len() >= self.config.min_coverage {
269                            let consensus = build_consensus_sequence(
270                                candidates,
271                                self.config.min_coverage,
272                                self.config.branching_threshold,
273                                self.config.extension_step,
274                            );
275                            if !consensus.is_empty() {
276                                let n_count = consensus.chars().filter(|&c| c == 'N').count();
277                                let n_ratio = n_count as f64 / consensus.len() as f64;
278
279                                if n_ratio <= self.config.max_n_ratio {
280                                    state.current_seq = format!("{}{}", state.current_seq, consensus);
281                                    state.right_failures = 0;
282                                    any_extended.store(true, std::sync::atomic::Ordering::Relaxed);
283                                } else {
284                                    state.right_failures += 1;
285                                }
286                            } else {
287                                state.right_failures += 1;
288                            }
289                        } else {
290                            state.right_failures += 1;
291                        }
292                    } else {
293                        state.right_failures += 1;
294                    }
295                }
296            });
297
298            if !any_extended.load(std::sync::atomic::Ordering::Relaxed) {
299                break;
300            }
301        }
302
303        let results = states.into_iter().map(|s| {
304            let s = s.into_inner().unwrap();
305            ExtendedContig {
306                name: s.name,
307                extended_seq: s.current_seq,
308            }
309        }).collect();
310
311        Ok(results)
312    }
313
314    #[inline]
315    pub fn extend_all_hybrid(&self, contigs: &[FastaRecord]) -> Result<Vec<ExtendedContig>> {
316        self.extend_contigs(contigs)
317    }
318}
319
/// Per-contig bookkeeping carried across extension rounds.
struct ContigState {
    /// Consecutive rounds in which the left end (start of `current_seq`) was not extended.
    left_failures: usize,
    /// Consecutive rounds in which the right end (end of `current_seq`) was not extended.
    right_failures: usize,
    /// Contig name, carried through to the output record.
    name: String,
    /// The sequence as extended so far.
    current_seq: String,
}
326
/// Canonical 2-bit hash of a k-mer.
///
/// The hash is the smaller of the packed forward encoding and the packed
/// reverse-complement encoding, so a k-mer and its reverse complement hash
/// identically. Bases are encoded A=0, G=1, C=2, T=3 (case-insensitive); with
/// this choice the complement of code `b` is `3 - b`.
///
/// Returns `None` if any character is not A/C/G/T.
///
/// Arithmetic wraps modulo 2^64, so for k > 32 only the last 32 bases of each
/// strand contribute. The hash is still strand-symmetric, and it no longer
/// overflows a shift. The extra collisions are confirmed or rejected by the
/// exact `check_kmer_match` comparison that follows every lookup in this module.
fn compute_kmer_hash(kmer: &str) -> Option<u64> {
    let mut forward = 0u64;
    let mut reverse = 0u64;
    // 4^i: weight of position i in the reverse-complement encoding (wraps to 0 at i = 32).
    let mut weight = 1u64;

    for &b in kmer.as_bytes() {
        let code: u64 = match b {
            b'A' | b'a' => 0,
            b'G' | b'g' => 1,
            b'C' | b'c' => 2,
            b'T' | b't' => 3,
            _ => return None,
        };
        forward = (forward << 2) | code;
        reverse = reverse.wrapping_add((3 - code).wrapping_mul(weight));
        weight = weight.wrapping_mul(4);
    }

    Some(forward.min(reverse))
}
346
/// Classifies how a read k-mer relates to a contig k-mer.
///
/// Returns `(is_forward, is_revcomp)`. `is_forward` is set when the two are
/// equal. `is_revcomp` is set when the read k-mer is the reverse complement of
/// the contig k-mer. At most one flag is set.
///
/// Comparison ignores ASCII case, matching `compute_kmer_hash`, which also
/// accepts lower-case bases. Non-ACGT characters complement to `N`. The check
/// runs without allocating.
fn check_kmer_match(read_kmer: &str, contig_kmer: &str) -> (bool, bool) {
    if read_kmer.eq_ignore_ascii_case(contig_kmer) {
        return (true, false);
    }

    let is_revcomp = read_kmer.len() == contig_kmer.len()
        && read_kmer
            .bytes()
            .rev()
            .zip(contig_kmer.bytes())
            .all(|(r, c)| {
                let complement = match r.to_ascii_uppercase() {
                    b'A' => b'T',
                    b'T' => b'A',
                    b'G' => b'C',
                    b'C' => b'G',
                    _ => b'N',
                };
                complement == c.to_ascii_uppercase()
            });

    (false, is_revcomp)
}
356
/// Returns the reverse complement of a nucleotide string.
///
/// Input case is ignored and the output is always upper-case; any character
/// other than A/C/G/T (either case) becomes `N`.
fn reverse_complement(seq: &str) -> String {
    let mut out = String::with_capacity(seq.len());
    for base in seq.chars().rev() {
        let paired = match base {
            'A' | 'a' => 'T',
            'T' | 't' => 'A',
            'G' | 'g' => 'C',
            'C' | 'c' => 'G',
            _ => 'N',
        };
        out.push(paired);
    }
    out
}
369
/// Builds a column-wise majority consensus of sequences aligned at position 0.
///
/// For each column `i` (at most `max_len`, and no further than the longest
/// input), the A/C/G/T calls at that position are counted case-insensitively.
/// Shorter sequences and other symbols are skipped. The consensus stops at the
/// first column with fewer than `min_coverage` calls, or with no calls at all.
/// A column emits `N` when the second most frequent base has at least
/// `branching_threshold` of the calls; otherwise it emits the majority base.
///
/// Inputs are indexed by byte, which is exact for ASCII nucleotide strings.
fn build_consensus_sequence(
    sequences: &[String],
    min_coverage: usize,
    branching_threshold: f64,
    max_len: usize,
) -> String {
    // Count slot order; also the output symbol for each slot.
    const BASES: [char; 4] = ['A', 'T', 'G', 'C'];

    let columns = sequences.iter().map(String::len).max().unwrap_or(0).min(max_len);
    let mut result = String::with_capacity(columns);

    for i in 0..columns {
        let mut counts = [0usize; 4];
        for seq in sequences {
            match seq.as_bytes().get(i).map(u8::to_ascii_uppercase) {
                Some(b'A') => counts[0] += 1,
                Some(b'T') => counts[1] += 1,
                Some(b'G') => counts[2] += 1,
                Some(b'C') => counts[3] += 1,
                _ => {}
            }
        }

        let total: usize = counts.iter().sum();
        // `total == 0` matters when min_coverage is 0: without it the ratio below is NaN.
        if total == 0 || total < min_coverage {
            break;
        }

        // max_by_key keeps the last maximum on ties; a tie is normally N anyway.
        let max_idx = counts
            .iter()
            .enumerate()
            .max_by_key(|&(_, &c)| c)
            .map(|(idx, _)| idx)
            .unwrap_or(0);

        let mut sorted_counts = counts;
        sorted_counts.sort_unstable_by(|a, b| b.cmp(a));
        let minor_freq = sorted_counts[1] as f64 / total as f64;

        result.push(if minor_freq >= branching_threshold { 'N' } else { BASES[max_idx] });
    }

    result
}
434
435pub fn write_extended_contigs(results: &[ExtendedContig], path: &Path) -> Result<()> {
436    let mut writer = BufWriter::new(File::create(path)?);
437
438    for result in results {
439        writeln!(writer, ">{}", result.name)?;
440        writeln!(writer, "{}", result.extended_seq)?;
441    }
442
443    Ok(())
444}
445
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_compute_kmer_hash() {
        let h1 = compute_kmer_hash("ATGC").unwrap();
        let h2 = compute_kmer_hash("ATGC").unwrap();
        assert_eq!(h1, h2);

        // GCAT is the reverse complement of ATGC, so the canonical hashes agree.
        let h3 = compute_kmer_hash("GCAT").unwrap();
        assert_eq!(h1, h3);

        // Lower-case bases hash like upper-case ones.
        assert_eq!(compute_kmer_hash("atgc"), Some(h1));

        assert!(compute_kmer_hash("ATNG").is_none());
    }

    #[test]
    fn test_kmer_hash_is_strand_symmetric() {
        // 21 bases, the default k-mer size.
        let kmer = "ACGTTGCAAGGCTTACCGATG";
        assert_eq!(
            compute_kmer_hash(kmer),
            compute_kmer_hash(&reverse_complement(kmer))
        );
    }

    #[test]
    fn test_reverse_complement() {
        assert_eq!(reverse_complement("ATGC"), "GCAT");
        assert_eq!(reverse_complement("AAAA"), "TTTT");
        assert_eq!(reverse_complement(""), "");
        // Output is upper-case and non-ACGT symbols become N.
        assert_eq!(reverse_complement("acgn"), "NCGT");
    }

    #[test]
    fn test_check_kmer_match() {
        assert_eq!(check_kmer_match("ATGC", "ATGC"), (true, false));
        assert_eq!(check_kmer_match("ATGC", "GCAT"), (false, true));
        assert_eq!(check_kmer_match("AAGC", "GCAT"), (false, false));
    }

    #[test]
    fn test_build_consensus() {
        let seqs = vec![
            "ATGC".to_string(),
            "ATGC".to_string(),
            "ATGC".to_string(),
        ];
        let consensus = build_consensus_sequence(&seqs, 2, 0.2, 100);
        assert_eq!(consensus, "ATGC");
    }

    #[test]
    fn test_build_consensus_branching_and_cutoffs() {
        fn owned(v: &[&str]) -> Vec<String> {
            v.iter().map(|s| s.to_string()).collect()
        }

        // A minor base with at least 20% of the calls produces N.
        assert_eq!(
            build_consensus_sequence(&owned(&["ATGC", "ATGA", "ATGC"]), 2, 0.2, 100),
            "ATGN"
        );
        // Stops once fewer than min_coverage sequences still have a base.
        assert_eq!(
            build_consensus_sequence(&owned(&["ATGCAA", "ATG", "ATGC"]), 2, 0.2, 100),
            "ATGC"
        );
        // Never longer than max_len.
        assert_eq!(build_consensus_sequence(&owned(&["ATGC", "ATGC"]), 2, 0.2, 2), "AT");
        assert_eq!(build_consensus_sequence(&[], 2, 0.2, 100), "");
    }
}
477        let consensus = build_consensus_sequence(&seqs, 2, 0.2, 100);
478        assert_eq!(consensus, "ATGC");
479    }
480
481}