// perbase_lib — par_granges.rs
1//! # ParGranges
2//!
3//! Iterates over chunked genomic regions in parallel.
4use anyhow::{Context, Result, anyhow};
5use bio::io::bed;
6use crossbeam::channel::{Receiver, bounded};
7use lazy_static::lazy_static;
8use log::*;
9use num_cpus;
10use rayon::prelude::*;
11use rust_htslib::{
12    bam::{HeaderView, IndexedReader, Read},
13    bcf::{Read as bcfRead, Reader},
14};
15use rust_lapper::{Interval, Lapper};
16use serde::Serialize;
17use std::{convert::TryInto, path::PathBuf, thread};
18
/// Number of bytes in one gigabyte; the unit used when sizing the bounded result channel.
const BYTES_INA_GIGABYTE: usize = 1024 * 1024 * 1024;

/// A modifier to apply to the channel size formula that is (BYTES_INA_GIGABYTE * channel_size_modifier) * threads / size_of(R::P)
/// 0.15 roughly corresponds to 1_000_000 PileupPosition objects per thread with some wiggle room.
pub const CHANNEL_SIZE_MODIFIER: f64 = 0.15;

/// The ideal number of basepairs each worker will receive. Total bp in memory at one time = `threads` * `chunksize`
pub const CHUNKSIZE: u32 = 1_000_000;

// String forms of the constants — presumably used as CLI default-value strings
// (e.g. for clap); confirm against the binary crate that consumes this lib.
lazy_static! {
    /// CHANNEL_SIZE_MODIFIER as a str
    pub static ref CHANNEL_SIZE_MODIFIER_STR: String = CHANNEL_SIZE_MODIFIER.to_string();

    /// CHUNKSIZE as a str
    pub static ref CHUNKSIZE_STR: String = CHUNKSIZE.to_string();
}
35
/// RegionProcessor defines the methods that must be implemented to process a region
pub trait RegionProcessor {
    /// A vector of P make up the output of [`process_region`] and
    /// are values associated with each position.
    ///
    /// `P` must be `Send + Sync` because results are produced on worker threads
    /// and sent over a channel, and `Serialize` so callers can write them out.
    ///
    /// [`process_region`]: #method.process_region
    type P: 'static + Send + Sync + Serialize;

    /// A function that takes the tid, start, and stop and returns something serializable.
    /// Note, a common use of this function will be a `fetch` -> `pileup`. The pileup must
    /// be bounds checked.
    fn process_region(&self, tid: u32, start: u32, stop: u32) -> Vec<Self::P>;
}
49
/// ParGranges holds all the information and configuration needed to launch the
/// [`ParGranges::process`].
///
/// [`ParGranges::process`]: #method.process
#[derive(Debug)]
pub struct ParGranges<R: 'static + RegionProcessor + Send + Sync> {
    /// Path to an indexed BAM / CRAM file
    reads: PathBuf,
    /// Optional reference file for CRAM
    ref_fasta: Option<PathBuf>,
    /// Optional path to a BED file to restrict the regions iterated over
    regions_bed: Option<PathBuf>,
    /// Optional path to a BCF/VCF file to restrict the regions iterated over
    regions_bcf: Option<PathBuf>,
    /// If `regions_bed` and or `regions_bcf` is specified, and this is true, merge any overlapping regions to avoid duplicate output.
    merge_regions: bool,
    /// Number of threads this is allowed to use, uses all if None
    threads: usize,
    /// The ideal number of basepairs each worker will receive. Total bp in memory at one time = `threads` * `chunksize`
    chunksize: u32,
    /// A modifier to apply to the channel size formula that is (BYTES_INA_GIGABYTE * channel_size_modifier) * threads / size_of(R::P)
    channel_size_modifier: f64,
    /// The rayon threadpool to operate in
    pool: rayon::ThreadPool,
    /// The implementation of [RegionProcessor] that will be used to process regions
    processor: R,
}
77
78impl<R: RegionProcessor + Send + Sync> ParGranges<R> {
79    /// Create a ParIO object
80    ///
81    /// # Arguments
82    ///
83    /// * `reads`- path to an indexed BAM/CRAM
84    /// * `ref_fasta`- path to an indexed reference file for CRAM
85    /// * `regions_bed`- Optional BED file path restricting the regions to be examined
86    /// * `regions_bcf`- Optional BCF/VCF file path restricting the regions to be examined
87    /// * `merge_regions` - If `regions_bed` and or `regions_bcf` is specified, and this is true, merge any overlapping regions to avoid duplicate output.
88    /// * `threads`- Optional threads to restrict the number of threads this process will use, defaults to all
89    /// * `chunksize`- optional argument to change the default chunksize of 1_000_000. `chunksize` determines the number of bases
90    ///   each worker will get to work on at one time.
91    /// * `channel_size_modifier`- Optional argument to modify the default size ration of the channel that `R::P` is sent on.
92    ///   formula is: ((BYTES_INA_GIGABYTE * channel_size_modifier) * threads) / size_of(R::P)
93    /// * `processor`- Something that implements [`RegionProcessor`](RegionProcessor)
94    #[allow(clippy::too_many_arguments)]
95    pub fn new(
96        reads: PathBuf,
97        ref_fasta: Option<PathBuf>,
98        regions_bed: Option<PathBuf>,
99        regions_bcf: Option<PathBuf>,
100        merge_regions: bool,
101        threads: Option<usize>,
102        chunksize: Option<u32>,
103        channel_size_modifier: Option<f64>,
104        processor: R,
105    ) -> Self {
106        let threads = if let Some(threads) = threads {
107            threads
108        } else {
109            num_cpus::get()
110        };
111
112        // Keep two around for main thread and thread running the pool
113        let threads = std::cmp::max(threads.saturating_sub(2), 1);
114        let pool = rayon::ThreadPoolBuilder::new()
115            .num_threads(threads)
116            .build()
117            .unwrap();
118
119        info!("Using {} worker threads.", threads);
120        Self {
121            reads,
122            ref_fasta,
123            regions_bed,
124            regions_bcf,
125            merge_regions,
126            threads,
127            chunksize: chunksize.unwrap_or(CHUNKSIZE),
128            channel_size_modifier: channel_size_modifier.unwrap_or(CHANNEL_SIZE_MODIFIER),
129            pool,
130            processor,
131        }
132    }
133
134    /// Validate inputs before launching the processing threads to catch errors gracefully.
135    fn validate(&self) -> Result<()> {
136        let mut reader = IndexedReader::from_path(&self.reads)
137            .with_context(|| format!("Failed to open indexed BAM/CRAM: {:?}", self.reads))?;
138
139        if let Some(ref_fasta) = &self.ref_fasta {
140            reader
141                .set_reference(ref_fasta)
142                .with_context(|| format!("Failed to set reference: {:?}", ref_fasta))?;
143        }
144
145        if let Some(regions_bed) = &self.regions_bed
146            && !regions_bed.exists()
147        {
148            return Err(anyhow!("BED file does not exist: {:?}", regions_bed));
149        }
150
151        if let Some(regions_bcf) = &self.regions_bcf
152            && !regions_bcf.exists()
153        {
154            return Err(anyhow!("BCF/VCF file does not exist: {:?}", regions_bcf));
155        }
156        Ok(())
157    }
158
    /// Process each region.
    ///
    /// This method splits the sequences in the BAM/CRAM header into `chunksize` * `self.threads` regions (aka 'super chunks').
    /// It then queries that 'super chunk' against the intervals (either the BED file, or the whole genome broken up into `chunksize`
    /// regions). The results of that query are then processed by a pool of workers that apply `process_region` to reach interval to
    /// do perbase analysis on. The collected result for each region is then sent back over the returned `Receiver<R::P>` channel
    /// for the caller to use. The results will be returned in order according to the order of the intervals used to drive this method.
    ///
    /// While one 'super chunk' is being worked on by all workers, the last 'super chunks' results are being printed to either to
    /// a file or to STDOUT, in order.
    ///
    /// Note, a common use case of this will be to fetch a region and do a pileup. The bounds of bases being looked at should still be
    /// checked since a fetch will pull all reads that overlap the region in question.
    ///
    /// # Errors
    ///
    /// Returns an error if `validate` fails (unopenable BAM/CRAM, bad reference, or missing region files).
    pub fn process(self) -> Result<Receiver<R::P>> {
        // Channel capacity = (bytes in 1 GB * modifier) / size_of(P), scaled by thread
        // count — bounds the memory consumed by in-flight results.
        let channel_size: usize = ((BYTES_INA_GIGABYTE as f64 * self.channel_size_modifier).floor()
            as usize
            / std::mem::size_of::<R::P>())
            * self.threads;
        info!(
            "Creating channel of length {:?} (* 120 bytes to get mem)",
            channel_size
        );

        // Fail fast on bad inputs before spawning the producer thread, where errors
        // would otherwise only surface as panics via the `expect` calls below.
        self.validate()?;

        let (snd, rxv) = bounded(channel_size);
        // Producer runs on its own OS thread; the caller consumes from `rxv`.
        thread::spawn(move || {
            self.pool.install(|| {
                info!("Reading from {:?}", self.reads);
                let mut reader = IndexedReader::from_path(&self.reads).expect("Indexed BAM/CRAM");
                // If passed add ref_fasta
                if let Some(ref_fasta) = &self.ref_fasta {
                    reader.set_reference(ref_fasta).expect("Set ref");
                }
                // Get a copy of the header
                let header = reader.header().to_owned();

                // Work out if we are restricted to a subset of sites
                let bed_intervals = if let Some(regions_bed) = &self.regions_bed {
                    Some(
                        Self::bed_to_intervals(&header, regions_bed, self.merge_regions)
                            .expect("Parsed BED to intervals"),
                    )
                } else {
                    None
                };
                let bcf_intervals = if let Some(regions_bcf) = &self.regions_bcf {
                    Some(
                        Self::bcf_to_intervals(&header, regions_bcf, self.merge_regions)
                            .expect("Parsed BCF/VCF to intervals"),
                    )
                } else {
                    None
                };
                // Combine restrictions: both present -> merged set; one present -> that
                // set; neither -> fall through to whole-genome chunks below.
                let restricted_ivs = match (bed_intervals, bcf_intervals) {
                    (Some(bed_ivs), Some(bcf_ivs)) => {
                        Some(Self::merge_intervals(bed_ivs, bcf_ivs, self.merge_regions))
                    }
                    (Some(bed_ivs), None) => Some(bed_ivs),
                    (None, Some(bcf_ivs)) => Some(bcf_ivs),
                    (None, None) => None,
                };

                let intervals = if let Some(restricted) = restricted_ivs {
                    restricted
                } else {
                    Self::header_to_intervals(&header, self.chunksize)
                        .expect("Parsed BAM/CRAM header to intervals")
                };

                // The number positions to try to process in one batch
                let serial_step_size = self.chunksize.saturating_mul(self.threads as u32); // aka superchunk
                for (tid, intervals) in intervals.into_iter().enumerate() {
                    let tid: u32 = tid as u32;
                    let tid_end: u32 = header.target_len(tid).unwrap().try_into().unwrap();
                    info!("Processing TID {}:0-{}", tid, tid_end);
                    // Result holds the processed positions to be sent to writer
                    let mut result = vec![];
                    for chunk_start in (0..tid_end).step_by(serial_step_size as usize) {
                        let tid_name = std::str::from_utf8(header.tid2name(tid)).unwrap();
                        let chunk_end = std::cmp::min(chunk_start + serial_step_size, tid_end);
                        trace!(
                            "Batch Processing {}:{}-{}",
                            tid_name, chunk_start, chunk_end
                        );
                        // Double-buffering: the first closure processes the current
                        // superchunk on the pool while the second sends the PREVIOUS
                        // superchunk's results down the channel, preserving order.
                        let (r, _) = rayon::join(
                            || {
                                // Must be a vec so that par_iter works and results stay in order
                                let ivs: Vec<Interval<u32, ()>> =
                                    Lapper::<u32, ()>::find(&intervals, chunk_start, chunk_end)
                                        // Truncate intervals that extend forward or backward of chunk in question
                                        .map(|iv| Interval {
                                            start: std::cmp::max(iv.start, chunk_start),
                                            stop: std::cmp::min(iv.stop, chunk_end),
                                            val: (),
                                        })
                                        .collect();
                                ivs.into_par_iter()
                                    .flat_map(|iv| {
                                        trace!("Processing {}:{}-{}", tid_name, iv.start, iv.stop);
                                        self.processor.process_region(tid, iv.start, iv.stop)
                                    })
                                    .collect()
                            },
                            || {
                                result.into_iter().for_each(|p| {
                                    snd.send(p).expect("Sent a serializable to writer")
                                })
                            },
                        );
                        result = r;
                    }
                    // Send final set of results
                    result
                        .into_iter()
                        .for_each(|p| snd.send(p).expect("Sent a serializable to writer"));
                }
            });
        });
        Ok(rxv)
    }
280
281    // Convert the header into intervals of equally sized chunks. The last interval may be short.
282    fn header_to_intervals(header: &HeaderView, chunksize: u32) -> Result<Vec<Lapper<u32, ()>>> {
283        let mut intervals = vec![vec![]; header.target_count() as usize];
284        for tid in 0..(header.target_count()) {
285            let tid_len: u32 = header.target_len(tid).unwrap().try_into().unwrap();
286            for start in (0..tid_len).step_by(chunksize as usize) {
287                let stop = std::cmp::min(start + chunksize, tid_len);
288                intervals[tid as usize].push(Interval {
289                    start,
290                    stop,
291                    val: (),
292                });
293            }
294        }
295        Ok(intervals.into_iter().map(Lapper::new).collect())
296    }
297
298    /// Read a bed file into a vector of lappers with the index representing the TID
299    /// if `merge' is true then any overlapping intervals in the sets will be merged.
300    // TODO add a proper error message
301    fn bed_to_intervals(
302        header: &HeaderView,
303        bed_file: &PathBuf,
304        merge: bool,
305    ) -> Result<Vec<Lapper<u32, ()>>> {
306        let mut bed_reader = bed::Reader::from_file(bed_file)?;
307        let mut intervals = vec![vec![]; header.target_count() as usize];
308        for (i, record) in bed_reader.records().enumerate() {
309            let record = record?;
310            let tid = header
311                .tid(record.chrom().as_bytes())
312                .expect("Chromosome not found in BAM/CRAM header");
313            let start = record
314                .start()
315                .try_into()
316                .with_context(|| format!("BED record {} is invalid: unable to parse start", i))?;
317            let stop = record
318                .end()
319                .try_into()
320                .with_context(|| format!("BED record {} is invalid: unable to parse stop", i))?;
321            if stop < start {
322                return Err(anyhow!("BED record {} is invalid: stop < start", i));
323            }
324            intervals[tid as usize].push(Interval {
325                start,
326                stop,
327                val: (),
328            });
329        }
330
331        Ok(intervals
332            .into_iter()
333            .map(|ivs| {
334                let mut lapper = Lapper::new(ivs);
335                if merge {
336                    lapper.merge_overlaps();
337                }
338                lapper
339            })
340            .collect())
341    }
342
343    /// Read a BCF/VCF file into a vector of lappers with index representing the TID
344    /// if `merge' is true then any overlapping intervals in the sets will be merged.
345    fn bcf_to_intervals(
346        header: &HeaderView,
347        bcf_file: &PathBuf,
348        merge: bool,
349    ) -> Result<Vec<Lapper<u32, ()>>> {
350        let mut bcf_reader = Reader::from_path(bcf_file).expect("Error opening BCF/VCF file.");
351        let bcf_header_reader = Reader::from_path(bcf_file).expect("Error opening BCF/VCF file.");
352        let bcf_header = bcf_header_reader.header();
353        let mut intervals = vec![vec![]; header.target_count() as usize];
354        // TODO: validate the headers against eachother
355        for record in bcf_reader.records() {
356            let record = record?;
357            let record_rid = bcf_header.rid2name(record.rid().unwrap()).unwrap();
358            let tid = header
359                .tid(record_rid)
360                .expect("Chromosome not found in BAM/CRAM header");
361            let pos: u32 = record
362                .pos()
363                .try_into()
364                .expect("Got a negative value for pos");
365            intervals[tid as usize].push(Interval {
366                start: pos,
367                stop: pos + 1,
368                val: (),
369            });
370        }
371
372        Ok(intervals
373            .into_iter()
374            .map(|ivs| {
375                let mut lapper = Lapper::new(ivs);
376                if merge {
377                    lapper.merge_overlaps();
378                }
379                lapper
380            })
381            .collect())
382    }
383
384    /// Merge two sets of restriction intervals together
385    /// if `merge' is true then any overlapping intervals in the sets will be merged.
386    fn merge_intervals(
387        a_ivs: Vec<Lapper<u32, ()>>,
388        b_ivs: Vec<Lapper<u32, ()>>,
389        merge: bool,
390    ) -> Vec<Lapper<u32, ()>> {
391        let mut intervals = vec![vec![]; a_ivs.len()];
392        for (i, (a_lapper, b_lapper)) in a_ivs.into_iter().zip(b_ivs.into_iter()).enumerate() {
393            intervals[i] = a_lapper.into_iter().chain(b_lapper.into_iter()).collect();
394        }
395        intervals
396            .into_iter()
397            .map(|ivs| {
398                let mut lapper = Lapper::new(ivs);
399                if merge {
400                    lapper.merge_overlaps();
401                }
402                lapper
403            })
404            .collect()
405    }
406}
407
#[cfg(test)]
mod test {
    use super::*;
    use bio::io::bed;
    use num_cpus;
    use proptest::prelude::*;
    use rust_htslib::{bam, bcf};
    use rust_lapper::{Interval, Lapper};
    use std::collections::{HashMap, HashSet};
    use tempfile::tempdir;
    // The purpose of these tests is to demonstrate that positions are covered once under a variety of circumstances

    prop_compose! {
        // Starts fall in the first half of the contig so start + size stays in range.
        fn arb_iv_start(max_iv: u64)(start in 0..max_iv/2) -> u64 { start }
    }
    prop_compose! {
        fn arb_iv_size(max_iv: u64)(size in 1..max_iv/2) -> u64 { size }
    }
    prop_compose! {
        // Create an arbitrary interval where the min size == max_iv / 2
        fn arb_iv(max_iv: u64)(start in arb_iv_start(max_iv), size in arb_iv_size(max_iv)) -> Interval<u64, ()> {
            Interval {start, stop: start + size, val: ()}
        }
    }
    // Create an arbitrary number of intervals along with the expected number of positions they cover
    fn arb_ivs(
        max_iv: u64,    // max iv size
        max_ivs: usize, // max number of intervals
    ) -> impl Strategy<Value = (Vec<Interval<u64, ()>>, u64, u64)> {
        prop::collection::vec(arb_iv(max_iv), 0..max_ivs).prop_map(|vec| {
            let mut furthest_right = 0;
            let lapper = Lapper::new(vec.clone());
            // `cov()` counts each covered position once, even under overlaps.
            let expected = lapper.cov();
            for iv in vec.iter() {
                if iv.stop > furthest_right {
                    furthest_right = iv.stop;
                }
            }
            (vec, expected, furthest_right)
        })
    }
    // Create arbitrary number of contigs with arbitrary intervals each
    fn arb_chrs(
        max_chr: usize, // number of chromosomes to use
        max_iv: u64,    // max interval size
        max_ivs: usize, // max number of intervals
    ) -> impl Strategy<Value = Vec<(Vec<Interval<u64, ()>>, u64, u64)>> {
        prop::collection::vec(arb_ivs(max_iv, max_ivs), 0..max_chr)
    }
    // An empty BAM with correct header
    // A BED file with the randomly generated intervals (with expected number of positions)
    // proptest generate random chunksize, cpus
    proptest! {
        #[test]
        // add random chunksize and random cpus
        // NB: using any larger numbers for this tends to blow up the test runtime
        fn interval_set(chromosomes in arb_chrs(4, 10_000, 1_000), chunksize in any::<u32>(), cpus in 0..num_cpus::get(), use_bed in any::<bool>(), use_vcf in any::<bool>()) {
            let tempdir = tempdir().unwrap();
            let bam_path = tempdir.path().join("test.bam");
            let bed_path = tempdir.path().join("test.bed");
            let vcf_path = tempdir.path().join("test.vcf");

            // Build a BAM
            let mut header = bam::header::Header::new();
            for (i,chr) in chromosomes.iter().enumerate() {
                let mut chr_rec = bam::header::HeaderRecord::new(b"SQ");
                chr_rec.push_tag(b"SN", &i.to_string());
                chr_rec.push_tag(b"LN", &chr.2.to_string()); // set len as max observed
                header.push_record(&chr_rec);
            }
            let writer = bam::Writer::from_path(&bam_path, &header, bam::Format::Bam).expect("Opened test.bam for writing");
            drop(writer); // force flush the writer so the header info is written
            bam::index::build(&bam_path, None, bam::index::Type::Bai, 1).unwrap();

            // Build a bed
            let mut writer = bed::Writer::to_file(&bed_path).expect("Opened test.bed for writing");
            for (i, chr) in chromosomes.iter().enumerate() {
                for iv in chr.0.iter() {
                    let mut record = bed::Record::new();
                    record.set_start(iv.start);
                    record.set_end(iv.stop);
                    record.set_chrom(&i.to_string());
                    record.set_score(&0.to_string());
                    writer.write(&record).expect("Wrote to test.bed");
                }
            }
            drop(writer); // force flush

            // Build a VCF file
            // vcf_truth tracks, per contig, the number of DISTINCT start positions —
            // the expected coverage when only the VCF restricts regions.
            let mut vcf_truth = HashMap::new();
            let mut header = bcf::header::Header::new();
            for (i,chr) in chromosomes.iter().enumerate() {
                header.push_record(format!("##contig=<ID={},length={}>", &i.to_string(), &chr.2.to_string()).as_bytes());
            }
            let mut writer = bcf::Writer::from_path(&vcf_path, &header, true, bcf::Format::Vcf).expect("Failed to open test.vcf for writing");
            let mut record = writer.empty_record();
            for (i, chr) in chromosomes.iter().enumerate() {
                record.set_rid(Some(i as u32));
                let counter = vcf_truth.entry(i).or_insert(0);
                let mut seen = HashSet::new();
                for iv in chr.0.iter() {
                    if !seen.contains(&iv.start) {
                        *counter += 1;
                        seen.insert(iv.start);
                    }
                    record.set_pos(iv.start as i64);
                    writer.write(&record).expect("Failed to write to test.vcf")
                }
            }

            drop(writer); // force flush
            // Create the processor with a dumb impl of processing that just returns positions with no counting
            let test_processor = TestProcessor {};
            let par_granges_runner = ParGranges::new(
                bam_path,
                None,
                if use_bed { Some(bed_path) } else { None }, // do one with regions
                if use_vcf { Some(vcf_path) } else { None }, // do one with vcf regions
                true,
                Some(cpus),
                Some(chunksize),
                Some(0.002),
                test_processor
            );
            let receiver = par_granges_runner.process().expect("Launch ParGranges Process");
            let mut chrom_counts = HashMap::new();
            receiver.into_iter().for_each(|p: PileupPosition| {
                let positions = chrom_counts.entry(p.ref_seq.parse::<usize>().expect("parsed chr")).or_insert(0u64);
                *positions += 1
            });

            // Validate that for each chr we get the expected number of bases
            for (chrom, positions) in chrom_counts.iter() {
                if use_bed  && !use_vcf {
                    // if this was with bed, should be equal to .1
                    prop_assert_eq!(chromosomes[*chrom].1, *positions, "chr: {}, expected: {}, found: {}", chrom, chromosomes[*chrom].1, positions);
                } else if use_bed && use_vcf {
                    // if this was with bed, should be equal to .1, bed restrictions and vcf restrctions should overlap
                    prop_assert_eq!(chromosomes[*chrom].1, *positions, "chr: {}, expected: {}, found: {}", chrom, chromosomes[*chrom].1, positions);
                } else if use_vcf && !use_bed {
                    // total positions should be equal to the number of records for that chr in the vcf
                    prop_assert_eq!(vcf_truth.get(chrom).unwrap(), positions, "chr: {}, expected: {}, found: {}", chrom, chromosomes[*chrom].1, positions);
                } else {
                    // if this was bam only, should be equal to rightmost postion
                    prop_assert_eq!(chromosomes[*chrom].2, *positions, "chr: {}, expected: {}, found: {}", chrom, chromosomes[*chrom].2, positions);
                }
            }

        }
    }

    use crate::position::{Position, pileup_position::PileupPosition};
    use smartstring::SmartString;
    // Trivial RegionProcessor: emits one PileupPosition per base in the region,
    // so the receiver-side count equals the number of bases processed.
    struct TestProcessor {}
    impl RegionProcessor for TestProcessor {
        type P = PileupPosition;

        fn process_region(&self, tid: u32, start: u32, stop: u32) -> Vec<Self::P> {
            let mut results = vec![];
            for i in start..stop {
                let chr = SmartString::from(&tid.to_string());
                let pos = PileupPosition::new(chr, i);
                results.push(pos);
            }
            results
        }
    }
}