1
2use anyhow::Result;
3use rayon::prelude::*;
4use rustc_hash::FxHashMap;
5use std::fs::File;
6use std::io::{BufWriter, Write};
7use std::path::Path;
8use std::sync::Mutex;
9
10use crate::seqio::{FastaRecord, FastqFile};
11
/// Tuning parameters controlling how contigs are extended.
#[derive(Clone)]
pub struct ExtenderConfig {
    /// K-mer length used both for indexing contig edges and matching reads.
    pub kmer_size: usize,
    /// Number of k-mers indexed from each contig end per extension round.
    pub num_edge_kmers: usize,
    /// Minimum number of supporting bases required at each consensus position.
    pub min_coverage: usize,
    /// Minor-allele frequency at or above which a consensus position is called 'N'.
    pub branching_threshold: f64,
    /// Maximum tolerated fraction of 'N' bases in an accepted extension.
    pub max_n_ratio: f64,
    /// Upper bound on the number of bases added per end per round
    /// (passed as `max_len` to consensus building).
    pub extension_step: usize,
    /// Consecutive failed rounds after which a contig end stops being extended.
    pub max_consecutive_failures: usize,
}
29
30impl Default for ExtenderConfig {
31 fn default() -> Self {
32 Self {
33 kmer_size: 21,
34 num_edge_kmers: 5,
35 min_coverage: 2,
36 branching_threshold: 0.2,
37 max_n_ratio: 0.05,
38 extension_step: 200,
39 max_consecutive_failures: 2,
40 }
41 }
42}
43
/// Result of extending a single input contig.
#[derive(Debug, Clone)]
pub struct ExtendedContig {
    /// Name carried over unchanged from the input contig.
    pub name: String,
    /// Contig sequence after extension; equals the input sequence when no
    /// extension succeeded.
    pub extended_seq: String,
}
51
/// Extends assembled contigs at both ends using an in-memory set of reads.
pub struct ContigExtender {
    // Tuning parameters supplied at construction time.
    config: ExtenderConfig,
    // All read sequences (R1 followed by R2), populated by load_reads().
    reads: Vec<String>,
}
56
57impl ContigExtender {
58
59 pub fn new(config: ExtenderConfig) -> Self {
60 Self {
61 config,
62 reads: Vec::new(),
63 }
64 }
65
66 pub fn load_reads(&mut self, r1_path: &Path, r2_path: &Path) -> Result<()> {
67 eprintln!("Loading reads into memory...");
68
69 let r1_owned = r1_path.to_path_buf();
70 let r2_owned = r2_path.to_path_buf();
71
72 let handle_r1 = std::thread::spawn(move || -> Result<Vec<String>> {
73 let mut reads = Vec::new();
74 let mut reader = FastqFile::open(&r1_owned)?;
75 while let Some(record) = reader.read_next()? {
76 reads.push(record.seq);
77 }
78 Ok(reads)
79 });
80
81 let handle_r2 = std::thread::spawn(move || -> Result<Vec<String>> {
82 let mut reads = Vec::new();
83 let mut reader = FastqFile::open(&r2_owned)?;
84 while let Some(record) = reader.read_next()? {
85 reads.push(record.seq);
86 }
87 Ok(reads)
88 });
89
90 let reads_r1 = handle_r1.join().map_err(|_| anyhow::anyhow!("R1 load thread panicked"))??;
91 let reads_r2 = handle_r2.join().map_err(|_| anyhow::anyhow!("R2 load thread panicked"))??;
92
93 self.reads = reads_r1;
94 self.reads.extend(reads_r2);
95
96 eprintln!("Loaded {} reads into memory", self.reads.len());
97 Ok(())
98 }
99
100 pub fn extend_contigs(&self, contigs: &[FastaRecord]) -> Result<Vec<ExtendedContig>> {
101 let k = self.config.kmer_size;
102 let max_failures = self.config.max_consecutive_failures;
103
104 let states: Vec<Mutex<ContigState>> = contigs.iter().map(|c| {
105 Mutex::new(ContigState {
106 name: c.name.clone(),
107 current_seq: c.seq.clone(),
108 left_failures: 0,
109 right_failures: 0,
110 })
111 }).collect();
112
113 loop {
114
115 let active_indices: Vec<usize> = states.iter().enumerate()
116 .filter(|(_, s)| {
117 let s = s.lock().unwrap();
118 s.left_failures < max_failures || s.right_failures < max_failures
119 })
120 .map(|(i, _)| i)
121 .collect();
122
123 if active_indices.is_empty() {
124 break;
125 }
126
127 let mut edge_kmers: FxHashMap<u64, Vec<(usize, bool, usize)>> = FxHashMap::default();
128
129 for &idx in &active_indices {
130 let state = states[idx].lock().unwrap();
131 let seq = &state.current_seq;
132 if seq.len() < k {
133 continue;
134 }
135
136 if state.left_failures < max_failures {
137 for offset in 0..self.config.num_edge_kmers.min(seq.len() - k + 1) {
138 if let Some(hash) = compute_kmer_hash(&seq[offset..offset+k]) {
139 edge_kmers.entry(hash).or_default().push((idx, true, offset));
140 }
141 }
142 }
143
144 if state.right_failures < max_failures {
145 let seq_len = seq.len();
146 for offset in 0..self.config.num_edge_kmers.min(seq.len() - k + 1) {
147 let start = seq_len - k - offset;
148 if let Some(hash) = compute_kmer_hash(&seq[start..start+k]) {
149 edge_kmers.entry(hash).or_default().push((idx, false, offset));
150 }
151 }
152 }
153 }
154
155 let left_candidates: Mutex<FxHashMap<usize, Vec<String>>> = Mutex::new(FxHashMap::default());
156 let right_candidates: Mutex<FxHashMap<usize, Vec<String>>> = Mutex::new(FxHashMap::default());
157
158 self.reads.par_iter().for_each(|read_seq| {
159 if read_seq.len() < k {
160 return;
161 }
162
163 let mut local_left: FxHashMap<usize, Vec<String>> = FxHashMap::default();
164 let mut local_right: FxHashMap<usize, Vec<String>> = FxHashMap::default();
165
166 for i in 0..=(read_seq.len() - k) {
167 let kmer_seq = &read_seq[i..i+k];
168 if let Some(hash) = compute_kmer_hash(kmer_seq) {
169 if let Some(matches) = edge_kmers.get(&hash) {
170 for &(contig_idx, is_left, edge_offset) in matches {
171 let state = states[contig_idx].lock().unwrap();
172 let contig_kmer = if is_left {
173 &state.current_seq[edge_offset..edge_offset+k]
174 } else {
175 let clen = state.current_seq.len();
176 &state.current_seq[clen-k-edge_offset..clen-edge_offset]
177 };
178
179 let (is_forward, is_revcomp) = check_kmer_match(kmer_seq, contig_kmer);
180 drop(state);
181
182 if is_left {
183 if is_forward && i > edge_offset {
184 let prefix = &read_seq[..i - edge_offset];
185 if !prefix.is_empty() {
186 let ext: String = prefix.chars().rev().collect();
187 local_left.entry(contig_idx).or_default().push(ext);
188 }
189 } else if is_revcomp && i + k + edge_offset < read_seq.len() {
190 let suffix = &read_seq[i+k+edge_offset..];
191 if !suffix.is_empty() {
192 let ext = reverse_complement(suffix);
193 local_left.entry(contig_idx).or_default().push(ext);
194 }
195 }
196 } else if is_forward && i + k + edge_offset < read_seq.len() {
197 let suffix = &read_seq[i+k+edge_offset..];
198 if !suffix.is_empty() {
199 local_right.entry(contig_idx).or_default().push(suffix.to_string());
200 }
201 } else if is_revcomp && i > edge_offset {
202 let prefix = &read_seq[..i - edge_offset];
203 if !prefix.is_empty() {
204 let ext = reverse_complement(prefix);
205 local_right.entry(contig_idx).or_default().push(ext);
206 }
207 }
208 }
209 }
210 }
211 }
212
213 if !local_left.is_empty() {
214 let mut global = left_candidates.lock().unwrap();
215 for (idx, candidates) in local_left {
216 global.entry(idx).or_default().extend(candidates);
217 }
218 }
219 if !local_right.is_empty() {
220 let mut global = right_candidates.lock().unwrap();
221 for (idx, candidates) in local_right {
222 global.entry(idx).or_default().extend(candidates);
223 }
224 }
225 });
226
227 let left_candidates = left_candidates.into_inner().unwrap();
228 let right_candidates = right_candidates.into_inner().unwrap();
229
230 let any_extended = std::sync::atomic::AtomicBool::new(false);
231
232 active_indices.par_iter().for_each(|&idx| {
233 let mut state = states[idx].lock().unwrap();
234
235 if state.left_failures < max_failures {
236 if let Some(candidates) = left_candidates.get(&idx) {
237 if candidates.len() >= self.config.min_coverage {
238 let consensus = build_consensus_sequence(
239 candidates,
240 self.config.min_coverage,
241 self.config.branching_threshold,
242 self.config.extension_step,
243 );
244 if !consensus.is_empty() {
245 let n_count = consensus.chars().filter(|&c| c == 'N').count();
246 let n_ratio = n_count as f64 / consensus.len() as f64;
247
248 if n_ratio <= self.config.max_n_ratio {
249 state.current_seq = format!("{}{}", consensus, state.current_seq);
250 state.left_failures = 0;
251 any_extended.store(true, std::sync::atomic::Ordering::Relaxed);
252 } else {
253 state.left_failures += 1;
254 }
255 } else {
256 state.left_failures += 1;
257 }
258 } else {
259 state.left_failures += 1;
260 }
261 } else {
262 state.left_failures += 1;
263 }
264 }
265
266 if state.right_failures < max_failures {
267 if let Some(candidates) = right_candidates.get(&idx) {
268 if candidates.len() >= self.config.min_coverage {
269 let consensus = build_consensus_sequence(
270 candidates,
271 self.config.min_coverage,
272 self.config.branching_threshold,
273 self.config.extension_step,
274 );
275 if !consensus.is_empty() {
276 let n_count = consensus.chars().filter(|&c| c == 'N').count();
277 let n_ratio = n_count as f64 / consensus.len() as f64;
278
279 if n_ratio <= self.config.max_n_ratio {
280 state.current_seq = format!("{}{}", state.current_seq, consensus);
281 state.right_failures = 0;
282 any_extended.store(true, std::sync::atomic::Ordering::Relaxed);
283 } else {
284 state.right_failures += 1;
285 }
286 } else {
287 state.right_failures += 1;
288 }
289 } else {
290 state.right_failures += 1;
291 }
292 } else {
293 state.right_failures += 1;
294 }
295 }
296 });
297
298 if !any_extended.load(std::sync::atomic::Ordering::Relaxed) {
299 break;
300 }
301 }
302
303 let results = states.into_iter().map(|s| {
304 let s = s.into_inner().unwrap();
305 ExtendedContig {
306 name: s.name,
307 extended_seq: s.current_seq,
308 }
309 }).collect();
310
311 Ok(results)
312 }
313
314 #[inline]
315 pub fn extend_all_hybrid(&self, contigs: &[FastaRecord]) -> Result<Vec<ExtendedContig>> {
316 self.extend_contigs(contigs)
317 }
318}
319
/// Mutable per-contig bookkeeping used during extension rounds.
struct ContigState {
    /// Contig name, carried through to the output record.
    name: String,
    /// Current (possibly already extended) sequence.
    current_seq: String,
    /// Consecutive rounds in which the left end failed to grow.
    left_failures: usize,
    /// Consecutive rounds in which the right end failed to grow.
    right_failures: usize,
}
326
/// Canonical 2-bit hash of a DNA k-mer: the minimum of the forward encoding
/// and the reverse-complement encoding, so a k-mer and its reverse complement
/// hash identically. Returns `None` if any base is not A/C/G/T (case-insensitive).
fn compute_kmer_hash(kmer: &str) -> Option<u64> {
    // Encoding chosen so that complementation is `3 - code` (A=0<->T=3, G=1<->C=2).
    fn encode(b: u8) -> Option<u64> {
        match b {
            b'A' | b'a' => Some(0),
            b'G' | b'g' => Some(1),
            b'C' | b'c' => Some(2),
            b'T' | b't' => Some(3),
            _ => None,
        }
    }

    let (forward, reverse) =
        kmer.bytes()
            .enumerate()
            .try_fold((0u64, 0u64), |(fwd, rev), (i, b)| {
                let code = encode(b)?;
                // Forward packs left-to-right; reverse places the complement
                // of base i at bit position 2*i, yielding the revcomp encoding.
                Some(((fwd << 2) | code, rev | ((3 - code) << (2 * i))))
            })?;

    Some(forward.min(reverse))
}
346
347fn check_kmer_match(read_kmer: &str, contig_kmer: &str) -> (bool, bool) {
348 let is_forward = read_kmer == contig_kmer;
349 let is_revcomp = if is_forward {
350 false
351 } else {
352 reverse_complement(read_kmer) == contig_kmer
353 };
354 (is_forward, is_revcomp)
355}
356
/// Returns the reverse complement of a DNA sequence. Input case is ignored;
/// output is uppercase, and any character other than A/C/G/T becomes 'N'.
fn reverse_complement(seq: &str) -> String {
    let mut out = String::with_capacity(seq.len());
    for c in seq.chars().rev() {
        let complement = match c.to_ascii_uppercase() {
            'A' => 'T',
            'T' => 'A',
            'G' => 'C',
            'C' => 'G',
            _ => 'N',
        };
        out.push(complement);
    }
    out
}
369
/// Builds a positional majority consensus from a set of candidate extensions.
///
/// All candidates are aligned at index 0. At each position, only A/C/G/T
/// votes (case-insensitive) count; the consensus stops at the first position
/// with fewer than `min_coverage` votes. A position whose second-most-common
/// base reaches `branching_threshold` of the votes is emitted as 'N'
/// (ambiguous branch). The result is capped at `max_len` characters.
///
/// Assumes ASCII sequence data (positions are byte offsets).
fn build_consensus_sequence(
    sequences: &[String],
    min_coverage: usize,
    branching_threshold: f64,
    max_len: usize,
) -> String {
    if sequences.is_empty() {
        return String::new();
    }

    // Index bytes directly: the previous `s.chars().nth(i)` per position was
    // O(i) per call, making the whole build quadratic in consensus length.
    let seqs: Vec<&[u8]> = sequences.iter().map(|s| s.as_bytes()).collect();

    let actual_max_len = seqs.iter().map(|s| s.len()).max().unwrap_or(0).min(max_len);
    let mut result = String::with_capacity(actual_max_len);

    for i in 0..actual_max_len {
        // Vote counts indexed as [A, T, G, C].
        let mut counts = [0usize; 4];
        for s in &seqs {
            if let Some(&b) = s.get(i) {
                match b.to_ascii_uppercase() {
                    b'A' => counts[0] += 1,
                    b'T' => counts[1] += 1,
                    b'G' => counts[2] += 1,
                    b'C' => counts[3] += 1,
                    _ => {} // non-ACGT bases do not vote
                }
            }
        }

        let total: usize = counts.iter().sum();
        // Stop at the coverage cliff. The explicit total == 0 guard also
        // prevents a 0/0 NaN (which previously fabricated an 'A') when
        // min_coverage is 0.
        if total < min_coverage || total == 0 {
            break;
        }

        // Ties resolve to the last maximum index, matching max_by_key.
        let max_idx = counts
            .iter()
            .enumerate()
            .max_by_key(|&(_, &c)| c)
            .map(|(idx, _)| idx)
            .unwrap_or(0);

        let mut sorted_counts = counts;
        sorted_counts.sort_unstable_by(|a, b| b.cmp(a));
        let minor_freq = sorted_counts[1] as f64 / total as f64;

        let base = if minor_freq >= branching_threshold {
            'N'
        } else {
            match max_idx {
                0 => 'A',
                1 => 'T',
                2 => 'G',
                3 => 'C',
                _ => 'N',
            }
        };

        result.push(base);
    }

    result
}
434
435pub fn write_extended_contigs(results: &[ExtendedContig], path: &Path) -> Result<()> {
436 let mut writer = BufWriter::new(File::create(path)?);
437
438 for result in results {
439 writeln!(writer, ">{}", result.name)?;
440 writeln!(writer, "{}", result.extended_seq)?;
441 }
442
443 Ok(())
444}
445
#[cfg(test)]
mod tests {
    use super::*;

    /// Equal k-mers hash equally, a k-mer and its reverse complement share a
    /// canonical hash, and ambiguous bases produce no hash.
    #[test]
    fn test_compute_kmer_hash() {
        let reference = compute_kmer_hash("ATGC").expect("valid k-mer");
        assert_eq!(compute_kmer_hash("ATGC"), Some(reference));
        // GCAT is the reverse complement of ATGC.
        assert_eq!(compute_kmer_hash("GCAT"), Some(reference));
        assert!(compute_kmer_hash("ATNG").is_none());
    }

    #[test]
    fn test_reverse_complement() {
        let cases = [("ATGC", "GCAT"), ("AAAA", "TTTT"), ("", "")];
        for (input, expected) in cases {
            assert_eq!(reverse_complement(input), expected);
        }
    }

    #[test]
    fn test_build_consensus() {
        let seqs: Vec<String> = (0..3).map(|_| "ATGC".to_string()).collect();
        let consensus = build_consensus_sequence(&seqs, 2, 0.2, 100);
        assert_eq!(consensus, "ATGC");
    }
}