redicat_lib/engine/par_granges/scheduler.rs
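//! Chunked, work-stealing execution of per-region computations over an indexed
//! BAM/CRAM, streaming results through a bounded crossbeam channel.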

use anyhow::{Context, Result};
use crossbeam::channel::{bounded, Receiver};
use log::*;
use num_cpus;
use rayon::prelude::*;
use rust_htslib::bam::{IndexedReader, Read};
use rust_lapper::Lapper;
use std::{
    convert::TryInto,
    path::PathBuf,
    sync::{
        atomic::{AtomicUsize, Ordering},
        Arc,
    },
    thread,
};

use super::intervals;
use super::types::{RegionProcessor, BYTES_IN_A_GIGABYTE, CHANNEL_SIZE_MODIFIER, CHUNKSIZE};

/// Parallel BAM/CRAM region executor driven by [`RegionProcessor`] implementations.
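///
/// `new` configures the run; [`ParGranges::process`] spawns the work on a Rayon
/// pool and hands back a bounded receiver to drain. A minimal usage sketch
/// (`MyProcessor` is an illustrative `RegionProcessor` implementation, not part
/// of this module):
///
/// ```ignore
/// let runner = ParGranges::new(
///     "sample.bam".into(), // reads
///     None,                // ref_fasta (needed for CRAM decoding)
///     None,                // regions_bed
///     None,                // regions_bcf
///     true,                // merge_regions
///     None,                // threads: defaults to all logical CPUs
///     None,                // chunksize: defaults to CHUNKSIZE
///     None,                // channel_size_modifier: defaults to CHANNEL_SIZE_MODIFIER
///     MyProcessor,
/// );
/// for result in runner.process()? {
///     // results stream in as regions complete
/// }
/// ```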
#[derive(Debug)]
pub struct ParGranges<R: 'static + RegionProcessor + Send + Sync> {
    reads: PathBuf,
    ref_fasta: Option<PathBuf>,
    regions_bed: Option<PathBuf>,
    regions_bcf: Option<PathBuf>,
    merge_regions: bool,
    threads: usize,
    chunksize: u32,
    channel_size_modifier: f64,
    pool: rayon::ThreadPool,
    processor: R,
}

impl<R: RegionProcessor + Send + Sync> ParGranges<R> {
    /// Create a new [`ParGranges`] executor.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        reads: PathBuf,
        ref_fasta: Option<PathBuf>,
        regions_bed: Option<PathBuf>,
        regions_bcf: Option<PathBuf>,
        merge_regions: bool,
        threads: Option<usize>,
        chunksize: Option<u32>,
        channel_size_modifier: Option<f64>,
        processor: R,
    ) -> Self {
        let requested_threads = threads.unwrap_or_else(num_cpus::get);
        let threads = std::cmp::max(requested_threads, 1);
        info!("Using {} worker threads.", threads);

        let pool = rayon::ThreadPoolBuilder::new()
            .num_threads(threads)
            .stack_size(2 * 1024 * 1024) // 2 MB stack per thread
            .build()
            .expect("Failed to build Rayon thread pool");

        Self {
            reads,
            ref_fasta,
            regions_bed,
            regions_bcf,
            merge_regions,
            threads,
            chunksize: chunksize.unwrap_or(CHUNKSIZE),
            channel_size_modifier: channel_size_modifier.unwrap_or(CHANNEL_SIZE_MODIFIER),
            pool,
            processor,
        }
    }

    /// Launch parallel processing for all configured regions.
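    ///
    /// The returned receiver is bounded at roughly
    /// `(BYTES_IN_A_GIGABYTE * channel_size_modifier / size_of::<R::P>()) * threads`
    /// items, so a slow consumer applies backpressure to the workers rather than
    /// buffering results without limit. Items arrive in completion order, not
    /// genomic order.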
    pub fn process(self) -> Result<Receiver<R::P>> {
        let ParGranges {
            reads,
            ref_fasta,
            regions_bed,
            regions_bcf,
            merge_regions,
            threads,
            chunksize,
            channel_size_modifier,
            pool,
            processor,
        } = self;

        // Size the channel as channel_size_modifier gigabytes' worth of items
        // per worker thread.
        let item_size = std::mem::size_of::<R::P>().max(1);
        let channel_size: usize =
            ((BYTES_IN_A_GIGABYTE as f64 * channel_size_modifier).floor() as usize / item_size)
                .saturating_mul(threads);
        info!(
            "Creating bounded channel with capacity {} items ({} bytes per item)",
            channel_size, item_size
        );

        let engine = Engine {
            reads,
            ref_fasta,
            regions_bed,
            regions_bcf,
            merge_regions,
            threads,
            chunksize,
            processor,
        };

        let (sender, receiver) = bounded::<R::P>(channel_size.max(1));
        thread::spawn(move || {
            pool.install(move || {
                if let Err(err) = engine.run(sender) {
                    error!("ParGranges terminated with error: {}", err);
                }
            });
        });
        Ok(receiver)
    }
}

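/// Run configuration moved out of [`ParGranges::process`]; executes on the
/// Rayon pool and feeds the bounded result channel.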
struct Engine<R: RegionProcessor + Send + Sync> {
    reads: PathBuf,
    ref_fasta: Option<PathBuf>,
    regions_bed: Option<PathBuf>,
    regions_bcf: Option<PathBuf>,
    merge_regions: bool,
    threads: usize,
    chunksize: u32,
    processor: R,
}

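/// A half-open interval `[start, stop)` on contig `tid`, dispatched as one unit
/// of work.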
#[derive(Clone, Copy, Debug)]
struct RegionTask {
    tid: u32,
    start: u32,
    stop: u32,
}

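/// Tile every interval into `tile`-sized [`RegionTask`]s, skipping contigs the
/// header reports as zero-length. For example, an interval `0..120` with
/// `tile == 50` yields tasks `[0, 50)`, `[50, 100)`, and `[100, 120)`.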
fn materialize_region_tasks(
    intervals: Vec<Lapper<u32, ()>>,
    target_info: &[(u32, String)],
    tile: u32,
    reserve: usize,
) -> Vec<RegionTask> {
    let tile = tile.max(1);
    let mut work = Vec::with_capacity(reserve);
    let target_len = target_info.len();

    for (tid_idx, contig_intervals) in intervals.into_iter().enumerate() {
        // Ignore intervals on contigs the header does not describe.
        if tid_idx >= target_len {
            break;
        }

        let (span, _) = target_info[tid_idx];
        if span == 0 {
            continue;
        }

        let tid = tid_idx as u32;
        for interval in contig_intervals.iter() {
            let mut cursor = interval.start;
            while cursor < interval.stop {
                // saturating_add guards against overflow when tile approaches u32::MAX.
                let stop = std::cmp::min(cursor.saturating_add(tile), interval.stop);
                if stop > cursor {
                    work.push(RegionTask {
                        tid,
                        start: cursor,
                        stop,
                    });
                }
                cursor = stop;
            }
        }
    }

    work
}

impl<R: RegionProcessor + Send + Sync> Engine<R> {
    fn run(self, sender: crossbeam::channel::Sender<R::P>) -> Result<()> {
        info!("Reading from {:?}", self.reads);
        let mut reader = IndexedReader::from_path(&self.reads)
            .with_context(|| format!("Failed to open BAM/CRAM {}", self.reads.display()))?;
        if let Err(e) = reader.set_threads(self.threads) {
            error!("Failed to set thread count to {}: {}", self.threads, e);
        }
        if let Some(ref_fasta) = &self.ref_fasta {
            reader
                .set_reference(ref_fasta)
                .with_context(|| format!("Failed to set reference {}", ref_fasta.display()))?;
        }
        let header = reader.header().to_owned();
        // Collect (length, name) per target; missing or non-convertible lengths
        // collapse to 0 and are skipped during task materialization.
        let target_info: Vec<(u32, String)> = (0..header.target_count())
            .map(|tid| {
                let len = header
                    .target_len(tid)
                    .and_then(|len| len.try_into().ok())
                    .unwrap_or(0);
                let name = std::str::from_utf8(header.tid2name(tid))
                    .unwrap_or("unknown")
                    .to_string();
                (len, name)
            })
            .collect();

        // Restrict to BED/BCF regions when provided; otherwise walk whole
        // contigs from the header in chunksize tiles.
        let bed_intervals = match &self.regions_bed {
            Some(path) => Some(intervals::bed_to_intervals(
                &header,
                path,
                self.merge_regions,
            )?),
            None => None,
        };
        let bcf_intervals = match &self.regions_bcf {
            Some(path) => Some(intervals::bcf_to_intervals(
                &header,
                path,
                self.merge_regions,
            )?),
            None => None,
        };

        let restricted = match (bed_intervals, bcf_intervals) {
            (Some(bed), Some(bcf)) => {
                Some(intervals::merge_intervals(bed, bcf, self.merge_regions))
            }
            (Some(bed), None) => Some(bed),
            (None, Some(bcf)) => Some(bcf),
            (None, None) => None,
        };

        let intervals = match restricted {
            Some(ivs) => ivs,
            None => intervals::header_to_intervals(&header, self.chunksize)?,
        };

        let tile = self.chunksize.max(1);

        // Upper bound on the task count if every contig were fully tiled; used
        // only to pre-size the work vector.
        let estimated_total_chunks: usize = target_info
            .iter()
            .filter(|(len, _)| *len > 0)
            .map(|(len, _)| (((*len - 1) / tile) + 1) as usize)
            .sum();

        let work = materialize_region_tasks(intervals, &target_info, tile, estimated_total_chunks);

        if work.is_empty() {
            info!("No intervals scheduled for processing; exiting early");
            return Ok(());
        }

        let total_chunks = work.len();
        let log_step = std::cmp::max(1, total_chunks / 10);
        trace!(
            "Scheduling {} region tasks (chunk size {}) across {} worker threads",
            total_chunks,
            tile,
            self.threads
        );

        let processed_chunks = AtomicUsize::new(0);
        let target_info = Arc::new(target_info);
        let total_chunks_f = total_chunks as f64;

        // Cap Rayon's split size so each steal grabs at most roughly
        // 1 / (8 * threads) of the total work, keeping workers evenly loaded.
        let worker_scale = (self.threads * 8).max(1);
        let scheduling_granularity =
            ((total_chunks + worker_scale.saturating_sub(1)) / worker_scale).max(1);

        work.into_par_iter()
            .with_min_len(1)
            .with_max_len(scheduling_granularity)
            .for_each_init(
                || (sender.clone(), Arc::clone(&target_info)),
                |(snd, target_info), task| {
                    trace!(
                        "Processing TID {} interval {}-{}",
                        task.tid,
                        task.start,
                        task.stop
                    );

                    let results = self
                        .processor
                        .process_region(task.tid, task.start, task.stop);
                    for item in results {
                        if snd.send(item).is_err() {
                            warn!("Channel closed; terminating region processing early");
                            return;
                        }
                    }

                    let completed = processed_chunks.fetch_add(1, Ordering::Relaxed) + 1;
                    if completed == total_chunks || completed % log_step == 0 {
                        let (_, tid_name) = &target_info[task.tid as usize];
                        let percent = (completed as f64 / total_chunks_f) * 100.0;
                        info!(
                            "Processed {:.1}% ({} / {} chunks) – {}:{}-{}",
                            percent, completed, total_chunks, tid_name, task.start, task.stop
                        );
                    }
                },
            );

        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use bio::io::bed;
    use proptest::prelude::*;
    use rust_htslib::{bam, bcf};
    use rust_lapper::{Interval, Lapper};
    use smartstring::SmartString;
    use std::collections::{HashMap, HashSet};
    use tempfile::tempdir;

    use crate::engine::position::pileup_position::PileupPosition;
    use crate::engine::position::Position;

    #[test]
    fn region_task_materialization_respects_chunk_size() {
        let intervals = vec![Lapper::new(vec![Interval {
            start: 0,
            stop: 120,
            val: (),
        }])];
        let target_info = vec![(120_u32, "chr1".to_string())];

        let tasks = super::materialize_region_tasks(intervals, &target_info, 50, 0);

        assert_eq!(tasks.len(), 3);
        assert_eq!(tasks[0].tid, 0);
        assert_eq!(tasks[0].start, 0);
        assert_eq!(tasks[0].stop, 50);
        assert_eq!(tasks[1].start, 50);
        assert_eq!(tasks[1].stop, 100);
        assert_eq!(tasks[2].start, 100);
        assert_eq!(tasks[2].stop, 120);
    }
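
    // Companion check (added sketch): contigs whose header length is zero are
    // skipped even when intervals exist for them.
    #[test]
    fn region_task_materialization_skips_zero_length_contigs() {
        let intervals = vec![
            Lapper::new(vec![Interval { start: 0, stop: 10, val: () }]),
            Lapper::new(vec![Interval { start: 0, stop: 10, val: () }]),
        ];
        let target_info = vec![(0_u32, "chr0".to_string()), (10_u32, "chr1".to_string())];

        let tasks = super::materialize_region_tasks(intervals, &target_info, 50, 0);

        assert_eq!(tasks.len(), 1);
        assert_eq!(tasks[0].tid, 1);
    }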

    struct TestProcessor;

    impl RegionProcessor for TestProcessor {
        type P = PileupPosition;

        fn process_region(&self, tid: u32, start: u32, stop: u32) -> Vec<Self::P> {
            (start..stop)
                .map(|pos| {
                    let chr = SmartString::from(&tid.to_string());
                    PileupPosition::new(chr, pos)
                })
                .collect()
        }
    }

    prop_compose! {
        fn arb_iv_start(max_iv: u64)(start in 0..max_iv/2) -> u64 { start }
    }
    prop_compose! {
        fn arb_iv_size(max_iv: u64)(size in 1..max_iv/2) -> u64 { size }
    }
    prop_compose! {
        fn arb_iv(max_iv: u64)(start in arb_iv_start(max_iv), size in arb_iv_size(max_iv)) -> Interval<u64, ()> {
            Interval { start, stop: start + size, val: () }
        }
    }
    // Each chromosome is generated as (intervals, covered bases, furthest right
    // stop); the last value doubles as the contig length in the BAM header.
    fn arb_ivs(
        max_iv: u64,
        max_ivs: usize,
    ) -> impl Strategy<Value = (Vec<Interval<u64, ()>>, u64, u64)> {
        prop::collection::vec(arb_iv(max_iv), 0..max_ivs).prop_map(|vec| {
            let mut furthest_right = 0;
            let lapper = Lapper::new(vec.clone());
            let expected = lapper.cov();
            for iv in vec.iter() {
                furthest_right = furthest_right.max(iv.stop);
            }
            (vec, expected, furthest_right)
        })
    }
    fn arb_chrs(
        max_chr: usize,
        max_iv: u64,
        max_ivs: usize,
    ) -> impl Strategy<Value = Vec<(Vec<Interval<u64, ()>>, u64, u64)>> {
        prop::collection::vec(arb_ivs(max_iv, max_ivs), 0..max_chr)
    }

    proptest! {
        #[test]
        fn interval_set(
            chromosomes in arb_chrs(4, 10_000, 1_000),
            chunksize in any::<u32>(),
            cpus in 0..num_cpus::get(),
            use_bed in any::<bool>(),
            use_vcf in any::<bool>(),
        ) {
            let tempdir = tempdir().unwrap();
            let bam_path = tempdir.path().join("test.bam");
            let bed_path = tempdir.path().join("test.bed");
            let vcf_path = tempdir.path().join("test.vcf");

            // An empty, indexed BAM suffices here: TestProcessor never fetches
            // reads, and the engine only needs the header's contig names and lengths.
            let mut header = bam::header::Header::new();
            for (i, chr) in chromosomes.iter().enumerate() {
                let mut chr_rec = bam::header::HeaderRecord::new(b"SQ");
                chr_rec.push_tag(b"SN", &i.to_string());
                chr_rec.push_tag(b"LN", &chr.2.to_string());
                header.push_record(&chr_rec);
            }
            let writer = bam::Writer::from_path(&bam_path, &header, bam::Format::Bam).unwrap();
            drop(writer);
            bam::index::build(&bam_path, None, bam::index::Type::Bai, 1).unwrap();

            let mut bed_writer = bed::Writer::to_file(&bed_path).unwrap();
            for (i, chr) in chromosomes.iter().enumerate() {
                for iv in chr.0.iter() {
                    let mut record = bed::Record::new();
                    record.set_start(iv.start);
                    record.set_end(iv.stop);
                    record.set_chrom(&i.to_string());
                    record.set_score(&0.to_string());
                    bed_writer.write(&record).unwrap();
                }
            }
            drop(bed_writer);

            // Track how many distinct single-base loci each chromosome receives;
            // this is the expected position count when only the VCF restricts regions.
            let mut vcf_truth = HashMap::new();
            let mut vcf_header = bcf::header::Header::new();
            for (i, chr) in chromosomes.iter().enumerate() {
                vcf_header.push_record(
                    format!("##contig=<ID={},length={}>", i, chr.2).as_bytes(),
                );
            }
            let mut vcf_writer = bcf::Writer::from_path(&vcf_path, &vcf_header, true, bcf::Format::Vcf).unwrap();
            let mut record = vcf_writer.empty_record();
            for (i, chr) in chromosomes.iter().enumerate() {
                record.set_rid(Some(i as u32));
                let counter = vcf_truth.entry(i).or_insert(0);
                let mut seen = HashSet::new();
                for iv in chr.0.iter() {
                    if seen.insert(iv.start) {
                        *counter += 1;
                    }
                    record.set_pos(iv.start as i64);
                    vcf_writer.write(&record).unwrap();
                }
            }
            drop(vcf_writer);

            let par_granges_runner = ParGranges::new(
                bam_path,
                None,
                if use_bed { Some(bed_path) } else { None },
                if use_vcf { Some(vcf_path) } else { None },
                true,
                Some(cpus + 1),         // always at least one worker thread
                Some(chunksize.max(1)), // a chunksize of 0 is clamped to 1
                Some(0.002),
                TestProcessor,
            );
            let receiver = par_granges_runner.process().unwrap();
            let mut chrom_counts = HashMap::new();
            receiver.into_iter().for_each(|p: PileupPosition| {
                *chrom_counts.entry(p.ref_seq.parse::<usize>().unwrap()).or_insert(0u64) += 1;
            });

            for (chrom, positions) in chrom_counts.iter() {
                if use_bed {
                    // BED alone, or BED merged with the VCF: every VCF locus is
                    // the start of a BED interval, so BED coverage is expected.
                    prop_assert_eq!(chromosomes[*chrom].1, *positions);
                } else if use_vcf {
                    // VCF alone: one single-base interval per distinct locus.
                    prop_assert_eq!(vcf_truth.get(chrom).unwrap(), positions);
                } else {
                    // No restriction: the whole contig is processed.
                    prop_assert_eq!(chromosomes[*chrom].2, *positions);
                }
            }
        }
    }
}