use anyhow::{Context, Result};
use crossbeam::channel::{bounded, Receiver};
use log::*;
use num_cpus;
use rayon::prelude::*;
use rust_htslib::bam::{IndexedReader, Read};
use rust_lapper::Lapper;
use std::{
    convert::TryInto,
    path::PathBuf,
    sync::{
        atomic::{AtomicUsize, Ordering},
        Arc,
    },
    thread,
};

use super::intervals;
use super::types::{RegionProcessor, BYTES_IN_A_GIGABYTE, CHANNEL_SIZE_MODIFIER, CHUNKSIZE};

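/// Schedules a [`RegionProcessor`] over genomic regions in parallel.
///
/// Regions come from an optional BED file, an optional BCF/VCF file, or both
/// (combined according to `merge_regions`); if neither is supplied, the
/// contigs listed in the BAM/CRAM header are processed end to end. Results
/// are streamed back to the caller over a bounded crossbeam channel.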
#[derive(Debug)]
pub struct ParGranges<R: 'static + RegionProcessor + Send + Sync> {
    /// Path to an indexed BAM or CRAM file.
    reads: PathBuf,
    /// Optional reference FASTA (needed for CRAM decoding).
    ref_fasta: Option<PathBuf>,
    /// Optional BED file restricting the regions to process.
    regions_bed: Option<PathBuf>,
    /// Optional BCF/VCF file restricting the regions to process.
    regions_bcf: Option<PathBuf>,
    /// Whether overlapping input regions should be merged.
    merge_regions: bool,
    /// Number of worker threads in the Rayon pool.
    threads: usize,
    /// Maximum width, in basepairs, of a single region task.
    chunksize: u32,
    /// Fraction of a gigabyte, per thread, used to size the output channel.
    channel_size_modifier: f64,
    /// The Rayon pool the engine runs on.
    pool: rayon::ThreadPool,
    /// The caller-supplied processor invoked once per region task.
    processor: R,
}

impl<R: RegionProcessor + Send + Sync> ParGranges<R> {
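    /// Creates a runner from the given inputs.
    ///
    /// `threads` defaults to `num_cpus::get()`, `chunksize` to `CHUNKSIZE`,
    /// and `channel_size_modifier` to `CHANNEL_SIZE_MODIFIER` when `None`.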
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        reads: PathBuf,
        ref_fasta: Option<PathBuf>,
        regions_bed: Option<PathBuf>,
        regions_bcf: Option<PathBuf>,
        merge_regions: bool,
        threads: Option<usize>,
        chunksize: Option<u32>,
        channel_size_modifier: Option<f64>,
        processor: R,
    ) -> Self {
        // Always run with at least one worker thread.
        let requested_threads = threads.unwrap_or_else(num_cpus::get);
        let threads = std::cmp::max(requested_threads, 1);
        info!("Using {} worker threads.", threads);

        let pool = rayon::ThreadPoolBuilder::new()
            .num_threads(threads)
            .stack_size(2 * 1024 * 1024) // 2 MiB per worker stack
            .build()
            .expect("Failed to build Rayon thread pool");

        Self {
            reads,
            ref_fasta,
            regions_bed,
            regions_bcf,
            merge_regions,
            threads,
            chunksize: chunksize.unwrap_or(CHUNKSIZE),
            channel_size_modifier: channel_size_modifier.unwrap_or(CHANNEL_SIZE_MODIFIER),
            pool,
            processor,
        }
    }

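    /// Consumes the runner and launches processing on the internal pool,
    /// returning the receiving end of a bounded channel that yields results
    /// as region tasks complete. Errors raised by the engine are logged
    /// rather than propagated, since processing happens on a detached thread.
    ///
    /// A minimal usage sketch; the path and `MyProcessor` are placeholders
    /// standing in for a real input and a caller-defined [`RegionProcessor`]
    /// (see `TestProcessor` in the tests below for a concrete one):
    ///
    /// ```ignore
    /// let runner = ParGranges::new(
    ///     PathBuf::from("example.bam"), // hypothetical input path
    ///     None,  // ref_fasta
    ///     None,  // regions_bed
    ///     None,  // regions_bcf
    ///     false, // merge_regions
    ///     None,  // threads: default to all CPUs
    ///     None,  // chunksize: default CHUNKSIZE
    ///     None,  // channel_size_modifier: default
    ///     MyProcessor,
    /// );
    /// let receiver = runner.process()?;
    /// for result in receiver.into_iter() {
    ///     // Consume results as they stream in; iteration ends once the
    ///     // engine finishes and drops the sender.
    /// }
    /// ```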
    pub fn process(self) -> Result<Receiver<R::P>> {
        let ParGranges {
            reads,
            ref_fasta,
            regions_bed,
            regions_bcf,
            merge_regions,
            threads,
            chunksize,
            channel_size_modifier,
            pool,
            processor,
        } = self;

        // Size the channel so that, per worker thread, buffered items occupy
        // at most `channel_size_modifier` of a gigabyte.
        let item_size = std::mem::size_of::<R::P>().max(1);
        let channel_size: usize =
            ((BYTES_IN_A_GIGABYTE as f64 * channel_size_modifier).floor() as usize / item_size)
                .saturating_mul(threads);
        info!(
            "Creating channel of length {} (* {} bytes per item)",
            channel_size, item_size
        );

        let engine = Engine {
            reads,
            ref_fasta,
            regions_bed,
            regions_bcf,
            merge_regions,
            threads,
            chunksize,
            processor,
        };

        // Run the engine on a detached thread so the receiver can be handed
        // back to the caller immediately; the channel closes when the engine
        // finishes and the sender is dropped.
        let (sender, receiver) = bounded::<R::P>(channel_size.max(1));
        thread::spawn(move || {
            pool.install(move || {
                if let Err(err) = engine.run(sender) {
                    error!("ParGranges terminated with error: {}", err);
                }
            });
        });
        Ok(receiver)
    }
}

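/// Owns everything needed to run the pipeline once [`ParGranges::process`]
/// has been called; moved onto the detached runner thread.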
struct Engine<R: RegionProcessor + Send + Sync> {
    reads: PathBuf,
    ref_fasta: Option<PathBuf>,
    regions_bed: Option<PathBuf>,
    regions_bcf: Option<PathBuf>,
    merge_regions: bool,
    threads: usize,
    chunksize: u32,
    processor: R,
}

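/// A single unit of work: the half-open interval `[start, stop)` on the
/// contig identified by `tid`.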
#[derive(Clone, Copy, Debug)]
struct RegionTask {
    tid: u32,
    start: u32,
    stop: u32,
}

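/// Tiles each contig's intervals into `RegionTask`s of at most `tile`
/// basepairs, skipping zero-length contigs and any interval set without a
/// matching entry in `target_info`. `reserve` pre-sizes the output vector.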
fn materialize_region_tasks(
    intervals: Vec<Lapper<u32, ()>>,
    target_info: &[(u32, String)],
    tile: u32,
    reserve: usize,
) -> Vec<RegionTask> {
    // A zero tile would never advance the cursor below, so clamp it to 1.
    let tile = tile.max(1);
    let mut work = Vec::with_capacity(reserve);
    let target_len = target_info.len();

    for (tid_idx, contig_intervals) in intervals.into_iter().enumerate() {
        if tid_idx >= target_len {
            break;
        }

        let (span, _) = target_info[tid_idx];
        if span == 0 {
            continue;
        }

        let tid = tid_idx as u32;
        for interval in contig_intervals.iter() {
            let mut cursor = interval.start;
            while cursor < interval.stop {
                let stop = std::cmp::min(cursor + tile, interval.stop);
                if stop > cursor {
                    work.push(RegionTask {
                        tid,
                        start: cursor,
                        stop,
                    });
                }
                cursor = stop;
            }
        }
    }

    work
}

impl<R: RegionProcessor + Send + Sync> Engine<R> {
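    /// Runs the full pipeline: opens the reader, resolves the region set,
    /// tiles it into tasks, and fans the tasks out across the Rayon pool,
    /// sending each produced item through `sender`.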
    fn run(self, sender: crossbeam::channel::Sender<R::P>) -> Result<()> {
        info!("Reading from {:?}", self.reads);
        let mut reader = IndexedReader::from_path(&self.reads)
            .with_context(|| format!("Failed to open BAM/CRAM {}", self.reads.display()))?;
        // Failing to raise the htslib thread count is not fatal; log and continue.
        if let Err(e) = reader.set_threads(self.threads) {
            error!("Failed to set thread count to {}: {}", self.threads, e);
        }
        if let Some(ref_fasta) = &self.ref_fasta {
            reader
                .set_reference(ref_fasta)
                .with_context(|| format!("Failed to set reference {}", ref_fasta.display()))?;
        }
        // Collect (length, name) per target; missing or oversized lengths
        // become 0 and are skipped during tiling.
        let header = reader.header().to_owned();
        let target_info: Vec<(u32, String)> = (0..header.target_count())
            .map(|tid| {
                let len = header
                    .target_len(tid)
                    .and_then(|len| len.try_into().ok())
                    .unwrap_or(0);
                let name = std::str::from_utf8(header.tid2name(tid))
                    .unwrap_or("unknown")
                    .to_string();
                (len, name)
            })
            .collect();

        let bed_intervals = match &self.regions_bed {
            Some(path) => Some(intervals::bed_to_intervals(
                &header,
                path,
                self.merge_regions,
            )?),
            None => None,
        };
        let bcf_intervals = match &self.regions_bcf {
            Some(path) => Some(intervals::bcf_to_intervals(
                &header,
                path,
                self.merge_regions,
            )?),
            None => None,
        };

        // Combine whatever restrictions were supplied; with none, fall back
        // to whole contigs from the header.
        let restricted = match (bed_intervals, bcf_intervals) {
            (Some(bed), Some(bcf)) => {
                Some(intervals::merge_intervals(bed, bcf, self.merge_regions))
            }
            (Some(bed), None) => Some(bed),
            (None, Some(bcf)) => Some(bcf),
            (None, None) => None,
        };

        let intervals = match restricted {
            Some(ivs) => ivs,
            None => intervals::header_to_intervals(&header, self.chunksize)?,
        };

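        // Estimate the total number of tiles up front so the task vector can
        // be allocated once; zero-length contigs contribute nothing.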
        let tile = self.chunksize.max(1);

        let estimated_total_chunks: usize = target_info
            .iter()
            .filter(|(len, _)| *len > 0)
            .map(|(len, _)| (((*len - 1) / tile) + 1) as usize)
            .sum();

        let work = materialize_region_tasks(intervals, &target_info, tile, estimated_total_chunks);

        if work.is_empty() {
            info!("No intervals scheduled for processing; exiting early");
            return Ok(());
        }

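        // Progress is logged roughly every 10% of completed chunks.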
        let total_chunks = work.len();
        let log_step = std::cmp::max(1, total_chunks / 10);
        trace!(
            "Scheduling {} region tasks (chunk size {}) across {} worker threads",
            total_chunks,
            tile,
            self.threads
        );

        let processed_chunks = AtomicUsize::new(0);
        let target_info = Arc::new(target_info);
        let total_chunks_f = total_chunks as f64;

        // Cap Rayon's batch size so the work splits into at least ~8 batches
        // per thread, keeping workers busy when chunk costs are uneven.
        let worker_scale = (self.threads * 8).max(1);
        let scheduling_granularity =
            ((total_chunks + worker_scale.saturating_sub(1)) / worker_scale).max(1);

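        // Fan out over the pool; `for_each_init` gives each Rayon work
        // segment its own clone of the sender and handle on the target info.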
        work.into_par_iter()
            .with_min_len(1)
            .with_max_len(scheduling_granularity)
            .for_each_init(
                || (sender.clone(), Arc::clone(&target_info)),
                |(snd, target_info), task| {
                    trace!(
                        "Processing TID {} interval {}-{}",
                        task.tid,
                        task.start,
                        task.stop
                    );

                    let results = self
                        .processor
                        .process_region(task.tid, task.start, task.stop);
                    for item in results {
                        // A closed channel means the receiver was dropped;
                        // stop sending rather than panic.
                        if snd.send(item).is_err() {
                            warn!("Channel closed; terminating region processing early");
                            return;
                        }
                    }

                    let completed = processed_chunks.fetch_add(1, Ordering::Relaxed) + 1;
                    if completed == total_chunks || completed % log_step == 0 {
                        let (_, tid_name) = &target_info[task.tid as usize];
                        let percent = (completed as f64 / total_chunks_f) * 100.0;
                        info!(
                            "Processed {:.1}% ({} / {} chunks) - {}:{}-{}",
                            percent, completed, total_chunks, tid_name, task.start, task.stop
                        );
                    }
                },
            );

        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use bio::io::bed;
    use proptest::prelude::*;
    use rust_htslib::{bam, bcf};
    use rust_lapper::{Interval, Lapper};
    use smartstring::SmartString;
    use std::collections::{HashMap, HashSet};
    use tempfile::tempdir;

    use crate::engine::position::pileup_position::PileupPosition;
    use crate::engine::position::Position;

    #[test]
    fn region_task_materialization_respects_chunk_size() {
        let intervals = vec![Lapper::new(vec![Interval {
            start: 0,
            stop: 120,
            val: (),
        }])];
        let target_info = vec![(120_u32, "chr1".to_string())];

        let tasks = super::materialize_region_tasks(intervals, &target_info, 50, 0);

        // A 120 bp interval tiled at 50 bp yields 0-50, 50-100, and 100-120.
        assert_eq!(tasks.len(), 3);
        assert_eq!(tasks[0].tid, 0);
        assert_eq!(tasks[0].start, 0);
        assert_eq!(tasks[0].stop, 50);
        assert_eq!(tasks[1].start, 50);
        assert_eq!(tasks[1].stop, 100);
        assert_eq!(tasks[2].start, 100);
        assert_eq!(tasks[2].stop, 120);
    }

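    /// Emits one `PileupPosition` per basepair of the requested region, using
    /// the numeric tid as the sequence name.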
    struct TestProcessor;

    impl RegionProcessor for TestProcessor {
        type P = PileupPosition;

        fn process_region(&self, tid: u32, start: u32, stop: u32) -> Vec<Self::P> {
            (start..stop)
                .map(|pos| {
                    let chr = SmartString::from(&tid.to_string());
                    PileupPosition::new(chr, pos)
                })
                .collect()
        }
    }

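    // Proptest strategies: `arb_iv` builds one random interval, `arb_ivs` a
    // set of intervals plus the set's expected coverage and rightmost stop,
    // and `arb_chrs` one such set per chromosome.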
    prop_compose! {
        fn arb_iv_start(max_iv: u64)(start in 0..max_iv/2) -> u64 { start }
    }
    prop_compose! {
        fn arb_iv_size(max_iv: u64)(size in 1..max_iv/2) -> u64 { size }
    }
    prop_compose! {
        fn arb_iv(max_iv: u64)(start in arb_iv_start(max_iv), size in arb_iv_size(max_iv)) -> Interval<u64, ()> {
            Interval { start, stop: start + size, val: () }
        }
    }
    fn arb_ivs(
        max_iv: u64,
        max_ivs: usize,
    ) -> impl Strategy<Value = (Vec<Interval<u64, ()>>, u64, u64)> {
        prop::collection::vec(arb_iv(max_iv), 0..max_ivs).prop_map(|vec| {
            let mut furthest_right = 0;
            let lapper = Lapper::new(vec.clone());
            let expected = lapper.cov();
            for iv in vec.iter() {
                furthest_right = furthest_right.max(iv.stop);
            }
            (vec, expected, furthest_right)
        })
    }
    fn arb_chrs(
        max_chr: usize,
        max_iv: u64,
        max_ivs: usize,
    ) -> impl Strategy<Value = Vec<(Vec<Interval<u64, ()>>, u64, u64)>> {
        prop::collection::vec(arb_ivs(max_iv, max_ivs), 0..max_chr)
    }

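    // End-to-end property test: write a header-only BAM plus optional BED and
    // VCF region files over random intervals, run ParGranges, and check that
    // the number of positions seen per contig matches expectations.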
    proptest! {
        #[test]
        fn interval_set(
            chromosomes in arb_chrs(4, 10_000, 1_000),
            chunksize in any::<u32>(),
            cpus in 0..num_cpus::get(),
            use_bed in any::<bool>(),
            use_vcf in any::<bool>(),
        ) {
            let tempdir = tempdir().unwrap();
            let bam_path = tempdir.path().join("test.bam");
            let bed_path = tempdir.path().join("test.bed");
            let vcf_path = tempdir.path().join("test.vcf");

            // Write an empty, indexed BAM whose header lists one contig per
            // chromosome, named by index and sized to the furthest stop.
            let mut header = bam::header::Header::new();
            for (i, chr) in chromosomes.iter().enumerate() {
                let mut chr_rec = bam::header::HeaderRecord::new(b"SQ");
                chr_rec.push_tag(b"SN", &i.to_string());
                chr_rec.push_tag(b"LN", &chr.2.to_string());
                header.push_record(&chr_rec);
            }
            let writer = bam::Writer::from_path(&bam_path, &header, bam::Format::Bam).unwrap();
            drop(writer);
            bam::index::build(&bam_path, None, bam::index::Type::Bai, 1).unwrap();

            let mut bed_writer = bed::Writer::to_file(&bed_path).unwrap();
            for (i, chr) in chromosomes.iter().enumerate() {
                for iv in chr.0.iter() {
                    let mut record = bed::Record::new();
                    record.set_start(iv.start);
                    record.set_end(iv.stop);
                    record.set_chrom(&i.to_string());
                    record.set_score(&0.to_string());
                    bed_writer.write(&record).unwrap();
                }
            }
            drop(bed_writer);

            // The VCF gets one record per interval start; `vcf_truth` counts
            // the distinct start positions per chromosome.
            let mut vcf_truth = HashMap::new();
            let mut vcf_header = bcf::header::Header::new();
            for (i, chr) in chromosomes.iter().enumerate() {
                vcf_header.push_record(
                    format!("##contig=<ID={},length={}>", i, chr.2).as_bytes(),
                );
            }
            let mut vcf_writer = bcf::Writer::from_path(&vcf_path, &vcf_header, true, bcf::Format::Vcf).unwrap();
            let mut record = vcf_writer.empty_record();
            for (i, chr) in chromosomes.iter().enumerate() {
                record.set_rid(Some(i as u32));
                let counter = vcf_truth.entry(i).or_insert(0);
                let mut seen = HashSet::new();
                for iv in chr.0.iter() {
                    if seen.insert(iv.start) {
                        *counter += 1;
                    }
                    record.set_pos(iv.start as i64);
                    vcf_writer.write(&record).unwrap();
                }
            }
            drop(vcf_writer);

            let par_granges_runner = ParGranges::new(
                bam_path,
                None,
                if use_bed { Some(bed_path) } else { None },
                if use_vcf { Some(vcf_path) } else { None },
                true,
                Some((cpus + 1).max(1)),
                Some(chunksize.max(1)),
                Some(0.002),
                TestProcessor,
            );
            let receiver = par_granges_runner.process().unwrap();
            let mut chrom_counts = HashMap::new();
            receiver.into_iter().for_each(|p: PileupPosition| {
                *chrom_counts.entry(p.ref_seq.parse::<usize>().unwrap()).or_insert(0u64) += 1;
            });

            // Every VCF position is the start of some BED interval, so whenever
            // the BED is in play the expected count is the BED coverage; with
            // the VCF alone it is the number of distinct starts; with no
            // regions it is the full contig length.
            for (chrom, positions) in chrom_counts.iter() {
                if use_bed {
                    prop_assert_eq!(chromosomes[*chrom].1, *positions);
                } else if use_vcf {
                    prop_assert_eq!(vcf_truth.get(chrom).unwrap(), positions);
                } else {
                    prop_assert_eq!(chromosomes[*chrom].2, *positions);
                }
            }
        }
    }
}