// holodeck_lib/bed.rs

1use std::fs::File;
2use std::io::{BufRead, BufReader, Read};
3use std::path::Path;
4
5use anyhow::{Context, Result, bail};
6use coitrees::{COITree, Interval, IntervalTree};
7use rand::Rng;
8
9use crate::sequence_dict::SequenceDictionary;
10
/// Leading two bytes shared by every gzip (and therefore bgzip) stream; used to
/// sniff compression from file content.
const GZIP_MAGIC: [u8; 2] = [0x1F, 0x8B];
13
14/// A collection of genomic target regions loaded from a BED file, indexed for
15/// efficient overlap queries.
16///
17/// Internally stores one [`COITree`] per contig for O(log N) overlap detection.
18/// BED coordinates are 0-based half-open `[start, end)`, but coitrees uses
19/// end-inclusive intervals, so we store `[start, end-1]` internally.
20///
21/// `Debug` is implemented manually because `COITree` does not implement `Debug`.
22pub struct TargetRegions {
23    /// One interval tree per contig, indexed by the contig's position in the
24    /// sequence dictionary. Empty trees for contigs with no targets.
25    trees: Vec<COITree<(), u32>>,
26    /// Total bases covered by all target regions.
27    total_territory: u64,
28    /// Per-contig target territory (bases), indexed by contig position.
29    per_contig_territory: Vec<u64>,
30    /// Sorted intervals per contig in 0-based half-open coordinates `[start, end)`,
31    /// used to build padded sampling regions for fragment start position selection.
32    sorted_intervals: Vec<Vec<(u32, u32)>>,
33    /// Sequence dictionary for contig lookups.
34    dict: SequenceDictionary,
35}
36
37impl std::fmt::Debug for TargetRegions {
38    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
39        f.debug_struct("TargetRegions")
40            .field("num_contigs", &self.trees.len())
41            .field("total_territory", &self.total_territory)
42            .field("dict", &self.dict)
43            .finish_non_exhaustive()
44    }
45}
46
47impl TargetRegions {
48    /// Load target regions from a BED file.
49    ///
50    /// The file may be plain text or gzipped. Coordinates are validated against
51    /// `dict` — unknown contigs or out-of-range coordinates cause an error.
52    ///
53    /// # Errors
54    /// Returns an error if the file cannot be read, contains invalid entries,
55    /// or references contigs not present in the dictionary.
56    pub fn from_path(path: &Path, dict: &SequenceDictionary) -> Result<Self> {
57        let file = File::open(path)
58            .with_context(|| format!("Failed to open BED file: {}", path.display()))?;
59
60        // Detect gzip by magic bytes, then re-open to read from the start.
61        let mut magic = [0u8; 2];
62        let is_gzipped = {
63            let mut peek = BufReader::new(file);
64            peek.read_exact(&mut magic).is_ok() && magic == GZIP_MAGIC
65        };
66
67        let file = File::open(path)?;
68        let reader: Box<dyn BufRead> = if is_gzipped {
69            Box::new(BufReader::new(flate2::read::MultiGzDecoder::new(file)))
70        } else {
71            Box::new(BufReader::new(file))
72        };
73
74        let mut intervals_by_contig: Vec<Vec<Interval<()>>> = vec![Vec::new(); dict.len()];
75        let mut raw_intervals_by_contig: Vec<Vec<(u32, u32)>> = vec![Vec::new(); dict.len()];
76        let mut total_territory: u64 = 0;
77        let mut per_contig_territory: Vec<u64> = vec![0; dict.len()];
78
79        for (line_num, line) in reader.lines().enumerate() {
80            let line =
81                line.with_context(|| format!("Failed to read line {} of BED file", line_num + 1))?;
82            let line = line.trim();
83            if line.is_empty()
84                || line.starts_with('#')
85                || line.starts_with("track ")
86                || line.starts_with("browser ")
87            {
88                continue;
89            }
90
91            let fields: Vec<&str> = line.split('\t').collect();
92            if fields.len() < 3 {
93                bail!("BED line {} has fewer than 3 fields: {line}", line_num + 1);
94            }
95
96            let contig = fields[0];
97            let start: u32 = fields[1].parse().with_context(|| {
98                format!("Invalid start coordinate on BED line {}: {}", line_num + 1, fields[1])
99            })?;
100            let end: u32 = fields[2].parse().with_context(|| {
101                format!("Invalid end coordinate on BED line {}: {}", line_num + 1, fields[2])
102            })?;
103
104            if start >= end {
105                bail!("BED line {} has start >= end: {start} >= {end}", line_num + 1);
106            }
107
108            let meta = dict.get_by_name(contig).ok_or_else(|| {
109                anyhow::anyhow!(
110                    "BED line {} references unknown contig '{contig}'. \
111                     Ensure the BED file matches the reference FASTA.",
112                    line_num + 1
113                )
114            })?;
115
116            #[expect(clippy::cast_possible_truncation, reason = "contig lengths fit in u32")]
117            let contig_len = meta.length() as u32;
118            if end > contig_len {
119                bail!(
120                    "BED line {} has end ({end}) > contig length ({contig_len}) for '{contig}'",
121                    line_num + 1
122                );
123            }
124
125            // coitrees uses end-inclusive: BED [start, end) → coitrees [start, end-1]
126            #[expect(clippy::cast_possible_wrap, reason = "genomic coords < i32::MAX")]
127            let iv = Interval::new(start as i32, (end - 1) as i32, ());
128            intervals_by_contig[meta.index()].push(iv);
129            raw_intervals_by_contig[meta.index()].push((start, end));
130            let bases = u64::from(end - start);
131            total_territory += bases;
132            per_contig_territory[meta.index()] += bases;
133        }
134
135        let trees: Vec<COITree<(), u32>> = intervals_by_contig.iter().map(COITree::new).collect();
136
137        // Sort raw intervals per contig for padded sampling region construction.
138        let sorted_intervals: Vec<Vec<(u32, u32)>> = raw_intervals_by_contig
139            .into_iter()
140            .map(|mut ivs| {
141                ivs.sort_unstable();
142                ivs
143            })
144            .collect();
145
146        Ok(Self {
147            trees,
148            total_territory,
149            per_contig_territory,
150            sorted_intervals,
151            dict: dict.clone(),
152        })
153    }
154
155    /// Return the total number of bases covered by all target regions.
156    ///
157    /// Note: if the BED file contains overlapping intervals, overlapping bases
158    /// are counted multiple times. Callers that need exact territory should
159    /// provide a non-overlapping BED.
160    #[must_use]
161    pub fn total_territory(&self) -> u64 {
162        self.total_territory
163    }
164
165    /// Return the target territory (in bases) for a specific contig.
166    #[must_use]
167    pub fn contig_territory(&self, contig_index: usize) -> u64 {
168        self.per_contig_territory.get(contig_index).copied().unwrap_or(0)
169    }
170
171    /// Check whether the interval `[start, end)` (0-based half-open) on the
172    /// given contig overlaps any target region.
173    #[must_use]
174    #[expect(clippy::cast_possible_wrap, reason = "genomic coords < i32::MAX")]
175    pub fn overlaps(&self, contig_index: usize, start: u32, end: u32) -> bool {
176        self.trees
177            .get(contig_index)
178            .is_some_and(|tree| tree.query_count(start as i32, (end.saturating_sub(1)) as i32) > 0)
179    }
180
181    /// Return the sorted target intervals for a contig in 0-based half-open
182    /// coordinates `[start, end)`.
183    #[must_use]
184    pub fn contig_intervals(&self, contig_index: usize) -> &[(u32, u32)] {
185        self.sorted_intervals.get(contig_index).map_or(&[], Vec::as_slice)
186    }
187
188    /// Compute the effective territory for coverage calculation, accounting for
189    /// the fact that fragments extend beyond target boundaries.
190    ///
191    /// For a target of width W, a fragment of length L placed uniformly to
192    /// overlap the target has an on-target fraction of `W / (W + L - 1)`.
193    /// The effective territory per target is therefore `W + L - 1` — the
194    /// catchment zone of fragment start positions — and the total effective
195    /// territory is the sum across all targets.  Using this as `effective_size`
196    /// in the standard coverage formula `N = C * effective_size / bases_per_read`
197    /// yields the correct number of reads for the desired mean target coverage.
198    #[must_use]
199    pub fn effective_territory(&self, fragment_mean: usize) -> u64 {
200        let l_minus_1 = fragment_mean.saturating_sub(1) as u64;
201        self.sorted_intervals
202            .iter()
203            .flat_map(|ivs| ivs.iter())
204            .map(|&(start, end)| u64::from(end - start) + l_minus_1)
205            .sum()
206    }
207
208    /// Compute the effective territory for a single contig.
209    ///
210    /// See [`effective_territory`](Self::effective_territory) for the rationale.
211    #[must_use]
212    pub fn contig_effective_territory(&self, contig_index: usize, fragment_mean: usize) -> u64 {
213        let l_minus_1 = fragment_mean.saturating_sub(1) as u64;
214        self.sorted_intervals.get(contig_index).map_or(0, |ivs| {
215            ivs.iter().map(|&(start, end)| u64::from(end - start) + l_minus_1).sum()
216        })
217    }
218
219    /// Return a reference to the underlying sequence dictionary.
220    #[must_use]
221    pub fn dict(&self) -> &SequenceDictionary {
222        &self.dict
223    }
224}
225
226/// A sampler that draws random fragment start positions from padded target regions.
227///
228/// Given a set of target intervals, each interval is padded on the left by a
229/// specified amount (to include fragments that start before a target but extend
230/// into it), overlapping padded intervals are merged, and start positions are
231/// sampled uniformly across the merged regions.  The resulting fragments should
232/// still be checked for overlap with the original unpadded targets — starts in
233/// the pad zone whose drawn fragment length is too short to reach the target
234/// will be rejected, but this is rare when the pad matches the expected max
235/// fragment length.
236pub struct PaddedIntervalSampler {
237    /// Merged padded intervals in 0-based half-open coordinates.
238    intervals: Vec<(u32, u32)>,
239    /// Cumulative territory sums for binary-search sampling.
240    /// `cumulative[j]` = total bases in intervals `0..=j`.
241    cumulative: Vec<u64>,
242    /// Total bases across all padded intervals.
243    total: u64,
244}
245
246impl PaddedIntervalSampler {
247    /// Build a sampler from sorted target intervals, padding each on the left.
248    ///
249    /// `pad` is typically the maximum expected fragment length (e.g.
250    /// `fragment_mean + 4 * fragment_stddev`), ensuring that fragments starting
251    /// before a target but overlapping it are represented in the sampling space.
252    /// Padded intervals are merged where they overlap and clamped to
253    /// `[0, contig_len)`.
254    #[must_use]
255    pub fn new(intervals: &[(u32, u32)], pad: u32, contig_len: u32) -> Self {
256        if intervals.is_empty() {
257            return Self { intervals: Vec::new(), cumulative: Vec::new(), total: 0 };
258        }
259
260        // Pad each interval on the left and clamp to [0, contig_len).
261        let mut padded: Vec<(u32, u32)> = intervals
262            .iter()
263            .map(|&(start, end)| (start.saturating_sub(pad), end.min(contig_len)))
264            .collect();
265        padded.sort_unstable();
266
267        // Merge overlapping or abutting intervals.
268        let mut merged: Vec<(u32, u32)> = Vec::with_capacity(padded.len());
269        for (start, end) in padded {
270            if let Some(last) = merged.last_mut()
271                && start <= last.1
272            {
273                last.1 = last.1.max(end);
274                continue;
275            }
276            merged.push((start, end));
277        }
278
279        // Build cumulative territory sums.
280        let mut cumulative = Vec::with_capacity(merged.len());
281        let mut running = 0u64;
282        for &(start, end) in &merged {
283            running += u64::from(end - start);
284            cumulative.push(running);
285        }
286        let total = running;
287
288        Self { intervals: merged, cumulative, total }
289    }
290
291    /// Sample a random start position uniformly from the padded intervals.
292    ///
293    /// Returns `None` if there are no intervals.
294    pub fn sample_start(&self, rng: &mut impl Rng) -> Option<u32> {
295        if self.total == 0 {
296            return None;
297        }
298
299        let r = rng.random_range(0..self.total);
300        let idx = self.cumulative.partition_point(|&c| c <= r);
301        let (start, _end) = self.intervals[idx];
302        let base_before = if idx > 0 { self.cumulative[idx - 1] } else { 0 };
303        let offset = r - base_before;
304
305        #[expect(clippy::cast_possible_truncation, reason = "offset within interval fits u32")]
306        Some(start + offset as u32)
307    }
308}
309
#[cfg(test)]
mod tests {
    use std::io::Write;

    use rand::SeedableRng;
    use tempfile::NamedTempFile;

    use super::*;
    use crate::sequence_dict::SequenceMetadata;

    /// Two-contig dictionary used by every test: chr1 (10000 bp), chr2 (5000 bp).
    fn test_dict() -> SequenceDictionary {
        SequenceDictionary::from_entries(vec![
            SequenceMetadata::new(0, "chr1".to_string(), 10000),
            SequenceMetadata::new(1, "chr2".to_string(), 5000),
        ])
    }

    /// Persist `content` to a temporary BED file; the file lives as long as the
    /// returned handle.
    fn write_bed(content: &str) -> NamedTempFile {
        let mut file = NamedTempFile::new().unwrap();
        file.write_all(content.as_bytes()).unwrap();
        file.flush().unwrap();
        file
    }

    /// Write `content` to a temp BED and load it against [`test_dict`].
    fn load(content: &str) -> anyhow::Result<TargetRegions> {
        let bed = write_bed(content);
        TargetRegions::from_path(bed.path(), &test_dict())
    }

    /// Fixed-seed RNG so sampler tests are reproducible.
    fn seeded_rng() -> rand::rngs::SmallRng {
        rand::rngs::SmallRng::seed_from_u64(42)
    }

    /// Draw `n` start positions from `sampler` with a freshly seeded RNG.
    fn draw(sampler: &PaddedIntervalSampler, n: usize) -> Vec<u32> {
        let mut rng = seeded_rng();
        (0..n).map(|_| sampler.sample_start(&mut rng).unwrap()).collect()
    }

    #[test]
    fn test_load_simple_bed() {
        let regions = load("chr1\t100\t200\nchr1\t300\t400\nchr2\t50\t150\n").unwrap();

        // Three 100 bp targets.
        assert_eq!(regions.total_territory(), 300);
    }

    #[test]
    fn test_overlap_hit() {
        let regions = load("chr1\t100\t200\n").unwrap();

        let hits = [
            (120, 180), // fully inside the target
            (50, 150),  // starts before, ends inside
            (150, 250), // starts inside, ends after
            (0, 300),   // engulfs the target
        ];
        for (start, end) in hits {
            assert!(regions.overlaps(0, start, end));
        }
    }

    #[test]
    fn test_overlap_miss() {
        let regions = load("chr1\t100\t200\n").unwrap();

        // Entirely before the target.
        assert!(!regions.overlaps(0, 0, 100));
        // Entirely after the target.
        assert!(!regions.overlaps(0, 200, 300));
        // Same coordinates, different contig.
        assert!(!regions.overlaps(1, 100, 200));
    }

    #[test]
    fn test_overlap_single_base() {
        let regions = load("chr1\t100\t200\n").unwrap();

        // One shared base at the target's first position.
        assert!(regions.overlaps(0, 99, 101));
        // One shared base at the target's last position.
        assert!(regions.overlaps(0, 199, 201));
        // Touching the target's end without sharing a base.
        assert!(!regions.overlaps(0, 200, 201));
    }

    #[test]
    fn test_skips_comments_and_blank_lines() {
        let regions = load("# header\n\nchr1\t100\t200\n\n").unwrap();
        assert_eq!(regions.total_territory(), 100);
    }

    #[test]
    fn test_error_unknown_contig() {
        let result = load("chrZ\t100\t200\n");
        assert!(result.is_err());
        let message = result.unwrap_err().to_string();
        assert!(message.contains("unknown contig"));
    }

    #[test]
    fn test_error_start_gte_end() {
        let result = load("chr1\t200\t100\n");
        assert!(result.is_err());
        let message = result.unwrap_err().to_string();
        assert!(message.contains("start >= end"));
    }

    #[test]
    fn test_error_end_exceeds_contig_length() {
        let result = load("chr1\t9000\t20000\n");
        assert!(result.is_err());
        let message = result.unwrap_err().to_string();
        assert!(message.contains("contig length"));
    }

    #[test]
    fn test_effective_territory_single_target() {
        // A single 100 bp target.
        let regions = load("chr1\t100\t200\n").unwrap();

        // W + L - 1 = 100 + 375 - 1
        assert_eq!(regions.effective_territory(375), 474);
        // L = 1 adds nothing beyond the target width.
        assert_eq!(regions.effective_territory(1), 100);
    }

    #[test]
    fn test_effective_territory_multiple_targets() {
        // chr1 carries two 100 bp targets, chr2 one 50 bp target.
        let regions = load("chr1\t100\t200\nchr1\t500\t600\nchr2\t0\t50\n").unwrap();

        // (100 + 374) + (100 + 374) + (50 + 374)
        assert_eq!(regions.effective_territory(375), 1372);
    }

    #[test]
    fn test_contig_effective_territory() {
        let regions = load("chr1\t100\t200\nchr1\t500\t600\nchr2\t0\t50\n").unwrap();

        // chr1: (100 + 374) + (100 + 374)
        assert_eq!(regions.contig_effective_territory(0, 375), 948);
        // chr2: 50 + 374
        assert_eq!(regions.contig_effective_territory(1, 375), 424);
    }

    #[test]
    fn test_contig_intervals_returns_sorted_intervals() {
        // chr1 records are deliberately listed in descending order.
        let regions = load("chr1\t300\t400\nchr1\t100\t200\nchr2\t50\t150\n").unwrap();

        assert_eq!(regions.contig_intervals(0), &[(100, 200), (300, 400)]);
        assert_eq!(regions.contig_intervals(1), &[(50, 150)]);
    }

    #[test]
    fn test_contig_intervals_empty_contig() {
        let regions = load("chr1\t100\t200\n").unwrap();
        assert!(regions.contig_intervals(1).is_empty());
    }

    // -- PaddedIntervalSampler tests --

    #[test]
    fn test_sampler_empty_intervals() {
        let sampler = PaddedIntervalSampler::new(&[], 100, 10000);
        let mut rng = seeded_rng();
        assert!(sampler.sample_start(&mut rng).is_none());
    }

    #[test]
    fn test_sampler_single_interval_no_pad() {
        let sampler = PaddedIntervalSampler::new(&[(100, 200)], 0, 10000);

        for pos in draw(&sampler, 1000) {
            assert!((100..200).contains(&pos), "pos {pos} not in [100, 200)");
        }
    }

    #[test]
    fn test_sampler_padding_extends_left() {
        // Target [500, 600) with a 200 bp pad is sampled from [300, 600).
        let sampler = PaddedIntervalSampler::new(&[(500, 600)], 200, 10000);
        let positions = draw(&sampler, 10_000);

        for &pos in &positions {
            assert!((300..600).contains(&pos), "pos {pos} not in [300, 600)");
        }

        // 10k draws over 300 positions should land close to both edges.
        let min_seen = positions.iter().copied().fold(u32::MAX, u32::min);
        let max_seen = positions.iter().copied().fold(0u32, u32::max);
        assert!(min_seen <= 310, "min_seen {min_seen} too high");
        assert!(max_seen >= 590, "max_seen {max_seen} too low");
    }

    #[test]
    fn test_sampler_padding_clamped_to_zero() {
        // Target [50, 150) near the contig start; a 200 bp pad saturates at 0,
        // giving [0, 150).
        let sampler = PaddedIntervalSampler::new(&[(50, 150)], 200, 10000);

        for pos in draw(&sampler, 1000) {
            assert!(pos < 150, "pos {pos} not in [0, 150)");
        }
    }

    #[test]
    fn test_sampler_merges_overlapping_padded_intervals() {
        // [200, 300) and [350, 450) padded by 100 become [100, 300) and
        // [250, 450), which merge into [100, 450).
        let sampler = PaddedIntervalSampler::new(&[(200, 300), (350, 450)], 100, 10000);

        for pos in draw(&sampler, 1000) {
            assert!((100..450).contains(&pos), "pos {pos} not in [100, 450)");
        }
    }

    #[test]
    fn test_sampler_keeps_disjoint_padded_intervals_separate() {
        // [100, 150) and [1000, 1050) padded by 50 become [50, 150) and
        // [950, 1050); they do not touch and must remain two regions.
        let sampler = PaddedIntervalSampler::new(&[(100, 150), (1000, 1050)], 50, 10000);

        for pos in draw(&sampler, 1000) {
            let in_first = (50..150).contains(&pos);
            let in_second = (950..1050).contains(&pos);
            assert!(in_first || in_second, "pos {pos} not in either padded interval");
        }
    }

    #[test]
    fn test_sampler_samples_proportional_to_interval_size() {
        // Padded sizes: [1000, 2000) → [900, 2000) = 1100 bp and
        // [5000, 5010) → [4900, 5010) = 110 bp, so draws should split ~10:1.
        let sampler = PaddedIntervalSampler::new(&[(1000, 2000), (5000, 5010)], 100, 10000);

        let mut count_first = 0u32;
        let mut count_second = 0u32;
        for pos in draw(&sampler, 11_000) {
            if (900..2000).contains(&pos) {
                count_first += 1;
            } else {
                count_second += 1;
            }
        }

        let ratio = f64::from(count_first) / f64::from(count_second);
        assert!(
            (8.0..12.0).contains(&ratio),
            "ratio {ratio:.1} not near expected 10:1 (first={count_first}, second={count_second})"
        );
    }
}