1use std::fs::File;
2use std::io::{BufRead, BufReader, Read};
3use std::path::Path;
4
5use anyhow::{Context, Result, bail};
6use coitrees::{COITree, Interval, IntervalTree};
7use rand::Rng;
8
9use crate::sequence_dict::SequenceDictionary;
10
/// The two leading bytes of a gzip stream; used to decide whether a BED file must be decompressed.
const GZIP_MAGIC: [u8; 2] = *b"\x1f\x8b";
13
14pub struct TargetRegions {
23 trees: Vec<COITree<(), u32>>,
26 total_territory: u64,
28 per_contig_territory: Vec<u64>,
30 sorted_intervals: Vec<Vec<(u32, u32)>>,
33 dict: SequenceDictionary,
35}
36
37impl std::fmt::Debug for TargetRegions {
38 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
39 f.debug_struct("TargetRegions")
40 .field("num_contigs", &self.trees.len())
41 .field("total_territory", &self.total_territory)
42 .field("dict", &self.dict)
43 .finish_non_exhaustive()
44 }
45}
46
47impl TargetRegions {
48 pub fn from_path(path: &Path, dict: &SequenceDictionary) -> Result<Self> {
57 let file = File::open(path)
58 .with_context(|| format!("Failed to open BED file: {}", path.display()))?;
59
60 let mut magic = [0u8; 2];
62 let is_gzipped = {
63 let mut peek = BufReader::new(file);
64 peek.read_exact(&mut magic).is_ok() && magic == GZIP_MAGIC
65 };
66
67 let file = File::open(path)?;
68 let reader: Box<dyn BufRead> = if is_gzipped {
69 Box::new(BufReader::new(flate2::read::MultiGzDecoder::new(file)))
70 } else {
71 Box::new(BufReader::new(file))
72 };
73
74 let mut intervals_by_contig: Vec<Vec<Interval<()>>> = vec![Vec::new(); dict.len()];
75 let mut raw_intervals_by_contig: Vec<Vec<(u32, u32)>> = vec![Vec::new(); dict.len()];
76 let mut total_territory: u64 = 0;
77 let mut per_contig_territory: Vec<u64> = vec![0; dict.len()];
78
79 for (line_num, line) in reader.lines().enumerate() {
80 let line =
81 line.with_context(|| format!("Failed to read line {} of BED file", line_num + 1))?;
82 let line = line.trim();
83 if line.is_empty()
84 || line.starts_with('#')
85 || line.starts_with("track ")
86 || line.starts_with("browser ")
87 {
88 continue;
89 }
90
91 let fields: Vec<&str> = line.split('\t').collect();
92 if fields.len() < 3 {
93 bail!("BED line {} has fewer than 3 fields: {line}", line_num + 1);
94 }
95
96 let contig = fields[0];
97 let start: u32 = fields[1].parse().with_context(|| {
98 format!("Invalid start coordinate on BED line {}: {}", line_num + 1, fields[1])
99 })?;
100 let end: u32 = fields[2].parse().with_context(|| {
101 format!("Invalid end coordinate on BED line {}: {}", line_num + 1, fields[2])
102 })?;
103
104 if start >= end {
105 bail!("BED line {} has start >= end: {start} >= {end}", line_num + 1);
106 }
107
108 let meta = dict.get_by_name(contig).ok_or_else(|| {
109 anyhow::anyhow!(
110 "BED line {} references unknown contig '{contig}'. \
111 Ensure the BED file matches the reference FASTA.",
112 line_num + 1
113 )
114 })?;
115
116 #[expect(clippy::cast_possible_truncation, reason = "contig lengths fit in u32")]
117 let contig_len = meta.length() as u32;
118 if end > contig_len {
119 bail!(
120 "BED line {} has end ({end}) > contig length ({contig_len}) for '{contig}'",
121 line_num + 1
122 );
123 }
124
125 #[expect(clippy::cast_possible_wrap, reason = "genomic coords < i32::MAX")]
127 let iv = Interval::new(start as i32, (end - 1) as i32, ());
128 intervals_by_contig[meta.index()].push(iv);
129 raw_intervals_by_contig[meta.index()].push((start, end));
130 let bases = u64::from(end - start);
131 total_territory += bases;
132 per_contig_territory[meta.index()] += bases;
133 }
134
135 let trees: Vec<COITree<(), u32>> = intervals_by_contig.iter().map(COITree::new).collect();
136
137 let sorted_intervals: Vec<Vec<(u32, u32)>> = raw_intervals_by_contig
139 .into_iter()
140 .map(|mut ivs| {
141 ivs.sort_unstable();
142 ivs
143 })
144 .collect();
145
146 Ok(Self {
147 trees,
148 total_territory,
149 per_contig_territory,
150 sorted_intervals,
151 dict: dict.clone(),
152 })
153 }
154
155 #[must_use]
161 pub fn total_territory(&self) -> u64 {
162 self.total_territory
163 }
164
165 #[must_use]
167 pub fn contig_territory(&self, contig_index: usize) -> u64 {
168 self.per_contig_territory.get(contig_index).copied().unwrap_or(0)
169 }
170
171 #[must_use]
174 #[expect(clippy::cast_possible_wrap, reason = "genomic coords < i32::MAX")]
175 pub fn overlaps(&self, contig_index: usize, start: u32, end: u32) -> bool {
176 self.trees
177 .get(contig_index)
178 .is_some_and(|tree| tree.query_count(start as i32, (end.saturating_sub(1)) as i32) > 0)
179 }
180
181 #[must_use]
184 pub fn contig_intervals(&self, contig_index: usize) -> &[(u32, u32)] {
185 self.sorted_intervals.get(contig_index).map_or(&[], Vec::as_slice)
186 }
187
188 #[must_use]
199 pub fn effective_territory(&self, fragment_mean: usize) -> u64 {
200 let l_minus_1 = fragment_mean.saturating_sub(1) as u64;
201 self.sorted_intervals
202 .iter()
203 .flat_map(|ivs| ivs.iter())
204 .map(|&(start, end)| u64::from(end - start) + l_minus_1)
205 .sum()
206 }
207
208 #[must_use]
212 pub fn contig_effective_territory(&self, contig_index: usize, fragment_mean: usize) -> u64 {
213 let l_minus_1 = fragment_mean.saturating_sub(1) as u64;
214 self.sorted_intervals.get(contig_index).map_or(0, |ivs| {
215 ivs.iter().map(|&(start, end)| u64::from(end - start) + l_minus_1).sum()
216 })
217 }
218
219 #[must_use]
221 pub fn dict(&self) -> &SequenceDictionary {
222 &self.dict
223 }
224}
225
/// Draws start positions uniformly from a set of left-padded, merged windows on one contig.
pub struct PaddedIntervalSampler {
    /// Combined width of all windows (the last `cumulative` entry, or 0 when there are none).
    total: u64,
    /// Sorted, non-overlapping half-open `(start, end)` windows that positions are drawn from.
    intervals: Vec<(u32, u32)>,
    /// `cumulative[i]` is the combined width of `intervals[..=i]`, used to map a draw to a window.
    cumulative: Vec<u64>,
}
245
246impl PaddedIntervalSampler {
247 #[must_use]
255 pub fn new(intervals: &[(u32, u32)], pad: u32, contig_len: u32) -> Self {
256 if intervals.is_empty() {
257 return Self { intervals: Vec::new(), cumulative: Vec::new(), total: 0 };
258 }
259
260 let mut padded: Vec<(u32, u32)> = intervals
262 .iter()
263 .map(|&(start, end)| (start.saturating_sub(pad), end.min(contig_len)))
264 .collect();
265 padded.sort_unstable();
266
267 let mut merged: Vec<(u32, u32)> = Vec::with_capacity(padded.len());
269 for (start, end) in padded {
270 if let Some(last) = merged.last_mut()
271 && start <= last.1
272 {
273 last.1 = last.1.max(end);
274 continue;
275 }
276 merged.push((start, end));
277 }
278
279 let mut cumulative = Vec::with_capacity(merged.len());
281 let mut running = 0u64;
282 for &(start, end) in &merged {
283 running += u64::from(end - start);
284 cumulative.push(running);
285 }
286 let total = running;
287
288 Self { intervals: merged, cumulative, total }
289 }
290
291 pub fn sample_start(&self, rng: &mut impl Rng) -> Option<u32> {
295 if self.total == 0 {
296 return None;
297 }
298
299 let r = rng.random_range(0..self.total);
300 let idx = self.cumulative.partition_point(|&c| c <= r);
301 let (start, _end) = self.intervals[idx];
302 let base_before = if idx > 0 { self.cumulative[idx - 1] } else { 0 };
303 let offset = r - base_before;
304
305 #[expect(clippy::cast_possible_truncation, reason = "offset within interval fits u32")]
306 Some(start + offset as u32)
307 }
308}
309
#[cfg(test)]
mod tests {
    use std::io::Write;

    use flate2::Compression;
    use flate2::write::GzEncoder;
    use rand::SeedableRng;
    use rand::rngs::SmallRng;
    use tempfile::NamedTempFile;

    use super::*;
    use crate::sequence_dict::SequenceMetadata;

    /// Two-contig dictionary shared by the tests: chr1 (10,000 bp) and chr2 (5,000 bp).
    fn test_dict() -> SequenceDictionary {
        let sequences = vec![
            SequenceMetadata::new(0, "chr1".to_string(), 10000),
            SequenceMetadata::new(1, "chr2".to_string(), 5000),
        ];
        SequenceDictionary::from_entries(sequences)
    }

    /// Writes raw bytes to a fresh temporary file that lives as long as the returned handle.
    fn write_file(content: &[u8]) -> NamedTempFile {
        let mut f = NamedTempFile::new().unwrap();
        f.write_all(content).unwrap();
        f.flush().unwrap();
        f
    }

    /// Writes `content` as a plain-text BED file.
    fn write_bed(content: &str) -> NamedTempFile {
        write_file(content.as_bytes())
    }

    /// Parses `content` as a BED file against [`test_dict`].
    fn load(content: &str) -> Result<TargetRegions> {
        let bed = write_bed(content);
        TargetRegions::from_path(bed.path(), &test_dict())
    }

    /// Parses `content` and returns the error message, panicking if parsing succeeded.
    fn load_err(content: &str) -> String {
        load(content).expect_err("expected BED parsing to fail").to_string()
    }

    /// Deterministic RNG so sampler tests are reproducible.
    fn rng() -> SmallRng {
        SmallRng::seed_from_u64(42)
    }

    #[test]
    fn test_load_simple_bed() {
        let regions = load("chr1\t100\t200\nchr1\t300\t400\nchr2\t50\t150\n").unwrap();
        assert_eq!(regions.total_territory(), 300);
    }

    #[test]
    fn test_contig_territory() {
        let regions = load("chr1\t100\t200\nchr1\t300\t400\nchr2\t50\t150\n").unwrap();
        assert_eq!(regions.contig_territory(0), 200);
        assert_eq!(regions.contig_territory(1), 100);
        // Out-of-range contig index reports no territory.
        assert_eq!(regions.contig_territory(5), 0);
    }

    #[test]
    fn test_load_gzipped_bed() {
        let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
        encoder.write_all(b"chr1\t100\t200\nchr2\t0\t50\n").unwrap();
        let bed = write_file(&encoder.finish().unwrap());
        let regions = TargetRegions::from_path(bed.path(), &test_dict()).unwrap();

        assert_eq!(regions.total_territory(), 150);
        assert_eq!(regions.contig_intervals(1), &[(0, 50)]);
    }

    #[test]
    fn test_overlap_hit() {
        let regions = load("chr1\t100\t200\n").unwrap();

        // Fully inside the target.
        assert!(regions.overlaps(0, 120, 180));
        // Straddles the left edge.
        assert!(regions.overlaps(0, 50, 150));
        // Straddles the right edge.
        assert!(regions.overlaps(0, 150, 250));
        // Fully contains the target.
        assert!(regions.overlaps(0, 0, 300));
    }

    #[test]
    fn test_overlap_miss() {
        let regions = load("chr1\t100\t200\n").unwrap();

        // Ends exactly where the target begins (half-open).
        assert!(!regions.overlaps(0, 0, 100));
        // Starts exactly where the target ends.
        assert!(!regions.overlaps(0, 200, 300));
        // Same coordinates on a contig without targets.
        assert!(!regions.overlaps(1, 100, 200));
        // Contig index outside the dictionary.
        assert!(!regions.overlaps(7, 0, 10));
    }

    #[test]
    fn test_overlap_single_base() {
        let regions = load("chr1\t100\t200\n").unwrap();

        // Shares only the first target base (100).
        assert!(regions.overlaps(0, 99, 101));
        // Shares only the last target base (199).
        assert!(regions.overlaps(0, 199, 201));
        // Base 200 is past the half-open end.
        assert!(!regions.overlaps(0, 200, 201));
    }

    #[test]
    fn test_skips_comments_and_blank_lines() {
        let regions = load("# header\n\nchr1\t100\t200\n\n").unwrap();
        assert_eq!(regions.total_territory(), 100);
    }

    #[test]
    fn test_skips_track_and_browser_lines() {
        let regions =
            load("browser position chr1:1-100\ntrack name=targets\nchr1\t100\t200\n").unwrap();
        assert_eq!(regions.total_territory(), 100);
    }

    #[test]
    fn test_handles_crlf_line_endings() {
        let regions = load("chr1\t100\t200\r\nchr2\t0\t50\r\n").unwrap();
        assert_eq!(regions.total_territory(), 150);
    }

    #[test]
    fn test_error_unknown_contig() {
        assert!(load_err("chrZ\t100\t200\n").contains("unknown contig"));
    }

    #[test]
    fn test_error_start_gte_end() {
        assert!(load_err("chr1\t200\t100\n").contains("start >= end"));
    }

    #[test]
    fn test_error_end_exceeds_contig_length() {
        assert!(load_err("chr1\t9000\t20000\n").contains("contig length"));
    }

    #[test]
    fn test_error_too_few_fields() {
        assert!(load_err("chr1\t100\n").contains("fewer than 3 fields"));
    }

    #[test]
    fn test_error_invalid_coordinate() {
        assert!(load_err("chr1\tabc\t200\n").contains("Invalid start coordinate"));
        assert!(load_err("chr1\t100\txyz\n").contains("Invalid end coordinate"));
    }

    #[test]
    fn test_effective_territory_single_target() {
        let regions = load("chr1\t100\t200\n").unwrap();

        // 100 bp target + (375 - 1) = 474.
        assert_eq!(regions.effective_territory(375), 474);
        // Fragment length 1 (or 0) adds nothing beyond the target itself.
        assert_eq!(regions.effective_territory(1), 100);
        assert_eq!(regions.effective_territory(0), 100);
    }

    #[test]
    fn test_effective_territory_multiple_targets() {
        let regions = load("chr1\t100\t200\nchr1\t500\t600\nchr2\t0\t50\n").unwrap();

        // (100 + 374) + (100 + 374) + (50 + 374) = 1372.
        assert_eq!(regions.effective_territory(375), 1372);
    }

    #[test]
    fn test_contig_effective_territory() {
        let regions = load("chr1\t100\t200\nchr1\t500\t600\nchr2\t0\t50\n").unwrap();

        // chr1: (100 + 374) * 2 = 948.
        assert_eq!(regions.contig_effective_territory(0, 375), 948);
        // chr2: 50 + 374 = 424.
        assert_eq!(regions.contig_effective_territory(1, 375), 424);
    }

    #[test]
    fn test_contig_intervals_returns_sorted_intervals() {
        // chr1 records are deliberately out of order in the file.
        let regions = load("chr1\t300\t400\nchr1\t100\t200\nchr2\t50\t150\n").unwrap();

        assert_eq!(regions.contig_intervals(0), &[(100, 200), (300, 400)]);
        assert_eq!(regions.contig_intervals(1), &[(50, 150)]);
    }

    #[test]
    fn test_contig_intervals_empty_contig() {
        let regions = load("chr1\t100\t200\n").unwrap();
        assert!(regions.contig_intervals(1).is_empty());
    }

    #[test]
    fn test_sampler_empty_intervals() {
        let sampler = PaddedIntervalSampler::new(&[], 100, 10000);
        assert!(sampler.sample_start(&mut rng()).is_none());
    }

    #[test]
    fn test_sampler_single_interval_no_pad() {
        let sampler = PaddedIntervalSampler::new(&[(100, 200)], 0, 10000);
        let mut rng = rng();

        for _ in 0..1000 {
            let pos = sampler.sample_start(&mut rng).unwrap();
            assert!((100..200).contains(&pos), "pos {pos} not in [100, 200)");
        }
    }

    #[test]
    fn test_sampler_padding_extends_left() {
        // Padding 200 turns [500, 600) into [300, 600).
        let sampler = PaddedIntervalSampler::new(&[(500, 600)], 200, 10000);
        let mut rng = rng();

        let mut min_seen = u32::MAX;
        let mut max_seen = 0u32;
        for _ in 0..10_000 {
            let pos = sampler.sample_start(&mut rng).unwrap();
            assert!((300..600).contains(&pos), "pos {pos} not in [300, 600)");
            min_seen = min_seen.min(pos);
            max_seen = max_seen.max(pos);
        }

        // With 10k draws over 300 positions both ends should be reached closely.
        assert!(min_seen <= 310, "min_seen {min_seen} too high");
        assert!(max_seen >= 590, "max_seen {max_seen} too low");
    }

    #[test]
    fn test_sampler_padding_clamped_to_zero() {
        // Padding 200 would move start 50 below zero; it must saturate at 0.
        let sampler = PaddedIntervalSampler::new(&[(50, 150)], 200, 10000);
        let mut rng = rng();

        for _ in 0..1000 {
            let pos = sampler.sample_start(&mut rng).unwrap();
            assert!(pos < 150, "pos {pos} not in [0, 150)");
        }
    }

    #[test]
    fn test_sampler_end_clamped_to_contig_length() {
        // The interval runs past the 150 bp contig end; samples must stay inside the contig.
        let sampler = PaddedIntervalSampler::new(&[(100, 200)], 0, 150);
        let mut rng = rng();

        for _ in 0..1000 {
            let pos = sampler.sample_start(&mut rng).unwrap();
            assert!((100..150).contains(&pos), "pos {pos} not in [100, 150)");
        }
    }

    #[test]
    fn test_sampler_merges_overlapping_padded_intervals() {
        // Padded windows [100, 300) and [250, 450) overlap and merge into [100, 450).
        let sampler = PaddedIntervalSampler::new(&[(200, 300), (350, 450)], 100, 10000);
        let mut rng = rng();

        for _ in 0..1000 {
            let pos = sampler.sample_start(&mut rng).unwrap();
            assert!((100..450).contains(&pos), "pos {pos} not in [100, 450)");
        }
    }

    #[test]
    fn test_sampler_keeps_disjoint_padded_intervals_separate() {
        // Padded windows [50, 150) and [950, 1050) stay apart; the gap is never sampled.
        let sampler = PaddedIntervalSampler::new(&[(100, 150), (1000, 1050)], 50, 10000);
        let mut rng = rng();

        for _ in 0..1000 {
            let pos = sampler.sample_start(&mut rng).unwrap();
            let in_first = (50..150).contains(&pos);
            let in_second = (950..1050).contains(&pos);
            assert!(in_first || in_second, "pos {pos} not in either padded interval");
        }
    }

    #[test]
    fn test_sampler_samples_proportional_to_interval_size() {
        // Padded widths are 1100 ([900, 2000)) and 110 ([4900, 5010)), a 10:1 ratio.
        let sampler = PaddedIntervalSampler::new(&[(1000, 2000), (5000, 5010)], 100, 10000);
        let mut rng = rng();

        let mut count_first = 0u32;
        let mut count_second = 0u32;
        for _ in 0..11_000 {
            let pos = sampler.sample_start(&mut rng).unwrap();
            if (900..2000).contains(&pos) {
                count_first += 1;
            } else {
                count_second += 1;
            }
        }

        let ratio = f64::from(count_first) / f64::from(count_second);
        assert!(
            (8.0..12.0).contains(&ratio),
            "ratio {ratio:.1} not near expected 10:1 (first={count_first}, second={count_second})"
        );
    }
}