1use std::{collections::HashMap, fmt::Debug};
37
38use gtars_core::models::{Interval, RegionSet};
39use num_traits::{PrimInt, Unsigned};
40use thiserror::Error;
41
42use crate::{AIList, Bits, Overlapper, OverlapperType};
43
44#[derive(Debug, Error)]
46pub enum MultiChromOverlapperError {
47 #[error("Error parsing region: {0}")]
49 RegionParsingError(String),
50 #[error("Error converting interval coordinates to u32: start={0}, end={1}")]
52 CoordinateConversionError(String, String),
53}
54
55pub struct MultiChromOverlapper<I, T> {
65 index_maps: HashMap<String, Box<dyn Overlapper<I, T>>>,
66 #[allow(dead_code)]
67 overlapper_type: OverlapperType,
68}
69
70pub struct IterFindOverlaps<'a, 'b, I, T>
99where
100 I: PrimInt + Unsigned + Send + Sync + Debug,
101 T: Eq + Clone + Send + Sync + Debug,
102{
103 inner: &'a HashMap<String, Box<dyn Overlapper<I, T>>>,
104 rs: &'b RegionSet,
105 region_idx: usize,
106 current_chr: Option<String>,
107 current_iter: Option<Box<dyn Iterator<Item = &'a Interval<I, T>> + 'a>>,
108}
109
110impl<'a, 'b, I, T> Iterator for IterFindOverlaps<'a, 'b, I, T>
111where
112 I: PrimInt + Unsigned + Send + Sync + Debug,
113 T: Eq + Clone + Send + Sync + Debug,
114{
115 type Item = (String, &'a Interval<I, T>);
116
117 fn next(&mut self) -> Option<Self::Item> {
118 loop {
119 #[allow(clippy::collapsible_if)]
121 if let Some(ref mut iter) = self.current_iter {
122 if let Some(interval) = iter.next() {
123 return Some((self.current_chr.as_ref().unwrap().clone(), interval));
124 }
125 }
126
127 if self.region_idx >= self.rs.regions.len() {
129 return None;
132 }
133
134 let region = &self.rs.regions[self.region_idx];
135 self.region_idx += 1;
136
137 if let Some(lapper) = self.inner.get(®ion.chr) {
139 if let (Some(start), Some(end)) = (I::from(region.start), I::from(region.end)) {
141 self.current_chr = Some(region.chr.clone());
142 self.current_iter = Some(lapper.find_iter(start, end));
143 } else {
145 panic!(
149 "Type conversion error: cannot convert Region coordinates to index type. \
150 Region: {}:{}-{}, expected type: {}",
151 region.chr,
152 region.start,
153 region.end,
154 std::any::type_name::<I>()
155 );
156 }
157 } else {
158 continue;
160 }
161 }
162 }
163}
164
165impl<I, T> MultiChromOverlapper<I, T>
166where
167 I: PrimInt + Unsigned + Send + Sync + Debug,
168 T: Eq + Clone + Send + Sync + Debug,
169{
170 pub fn find_overlaps_iter<'a, 'b>(
174 &'a self,
175 rs: &'b RegionSet,
176 ) -> IterFindOverlaps<'a, 'b, I, T> {
177 IterFindOverlaps {
178 inner: &self.index_maps,
179 rs,
180 region_idx: 0,
181 current_chr: None,
182 current_iter: None,
183 }
184 }
185
186 pub fn find_overlaps(&self, rs: &RegionSet) -> Vec<(String, Interval<I, T>)> {
191 self.find_overlaps_iter(rs)
192 .map(|(chr, interval)| (chr, interval.clone()))
193 .collect()
194 }
195}
196
197pub trait IntoMultiChromOverlapper<I, T>
221where
222 I: PrimInt + Unsigned + Send + Sync,
223 T: Eq + Clone + Send + Sync,
224{
225 fn into_multi_chrom_overlapper(
235 self,
236 overlapper_type: OverlapperType,
237 ) -> MultiChromOverlapper<I, T>;
238}
239
240impl IntoMultiChromOverlapper<u32, Option<String>> for RegionSet {
241 fn into_multi_chrom_overlapper(
242 self,
243 overlapper_type: OverlapperType,
244 ) -> MultiChromOverlapper<u32, Option<String>> {
245 let mut core: HashMap<String, Box<dyn Overlapper<u32, Option<String>>>> =
247 HashMap::default();
248 let mut intervals: HashMap<String, Vec<Interval<u32, Option<String>>>> = HashMap::default();
249
250 for region in self.regions.into_iter() {
252 let interval = Interval {
254 start: region.start,
255 end: region.end,
256 val: region.rest,
257 };
258
259 let chr_intervals = intervals.entry(region.chr.clone()).or_default();
261
262 chr_intervals.push(interval);
264 }
265
266 for (chr, chr_intervals) in intervals.into_iter() {
268 let lapper: Box<dyn Overlapper<u32, Option<String>>> = match overlapper_type {
269 OverlapperType::Bits => Box::new(Bits::build(chr_intervals)),
270 OverlapperType::AIList => Box::new(AIList::build(chr_intervals)),
271 };
272 core.insert(chr.to_string(), lapper);
273 }
274
275 MultiChromOverlapper {
276 index_maps: core,
277 overlapper_type,
278 }
279 }
280}
281
#[cfg(test)]
mod tests {
    use super::*;
    use gtars_core::models::Region;
    use pretty_assertions::assert_eq;
    use rstest::*;

    /// Region on `chr` with the given `start` and `end` and no metadata.
    fn plain(chr: &str, start: u32, end: u32) -> Region {
        Region {
            chr: chr.to_string(),
            start,
            end,
            rest: None,
        }
    }

    /// Region on `chr` whose `rest` field carries `label`.
    fn labelled(chr: &str, start: u32, end: u32, label: &str) -> Region {
        Region {
            chr: chr.to_string(),
            start,
            end,
            rest: Some(label.to_string()),
        }
    }

    /// Indexes `indexed` with the requested overlapper kind and returns every
    /// hit reported for `queries`.
    fn overlaps(
        overlapper_type: OverlapperType,
        indexed: Vec<Region>,
        queries: Vec<Region>,
    ) -> Vec<(String, Interval<u32, Option<String>>)> {
        let index = RegionSet::from(indexed).into_multi_chrom_overlapper(overlapper_type);
        let query = RegionSet::from(queries);
        index.find_overlaps(&query)
    }

    /// A single query overlapping exactly one of three indexed regions.
    #[rstest]
    #[case(OverlapperType::AIList)]
    #[case(OverlapperType::Bits)]
    fn test_basic_overlaps(#[case] overlapper_type: OverlapperType) {
        let hits = overlaps(
            overlapper_type,
            vec![
                plain("chr1", 100, 200),
                plain("chr1", 300, 400),
                plain("chr1", 600, 800),
            ],
            vec![plain("chr1", 110, 210)],
        );

        assert_eq!(hits.len(), 1);
        let (chr, interval) = &hits[0];
        assert_eq!(chr, "chr1");
        assert_eq!(interval.start, 100);
        assert_eq!(interval.end, 200);
    }

    /// A single query that falls inside three mutually overlapping regions.
    #[rstest]
    #[case(OverlapperType::AIList)]
    #[case(OverlapperType::Bits)]
    fn test_multiple_overlaps_single_query(#[case] overlapper_type: OverlapperType) {
        let hits = overlaps(
            overlapper_type,
            vec![
                plain("chr1", 100, 200),
                plain("chr1", 150, 250),
                plain("chr1", 180, 300),
            ],
            vec![plain("chr1", 160, 190)],
        );

        assert_eq!(hits.len(), 3);
    }

    /// A query on the right chromosome but past every indexed region.
    #[rstest]
    #[case(OverlapperType::AIList)]
    #[case(OverlapperType::Bits)]
    fn test_no_overlaps(#[case] overlapper_type: OverlapperType) {
        let hits = overlaps(
            overlapper_type,
            vec![plain("chr1", 100, 200), plain("chr1", 300, 400)],
            vec![plain("chr1", 500, 600)],
        );

        assert_eq!(hits.len(), 0);
    }

    /// Queries on two chromosomes are each answered from their own index.
    #[rstest]
    #[case(OverlapperType::AIList)]
    #[case(OverlapperType::Bits)]
    fn test_multiple_chromosomes(#[case] overlapper_type: OverlapperType) {
        let hits = overlaps(
            overlapper_type,
            vec![
                plain("chr1", 100, 200),
                plain("chr2", 300, 400),
                plain("chr3", 500, 600),
            ],
            vec![plain("chr1", 150, 250), plain("chr2", 350, 450)],
        );

        assert_eq!(hits.len(), 2);

        let hits_on = |name: &str| hits.iter().filter(|(chr, _)| chr == name).count();
        assert_eq!(hits_on("chr1"), 1);
        assert_eq!(hits_on("chr2"), 1);
    }

    /// The query starts at the coordinate where the indexed region ends;
    /// the expected hit count is zero.
    #[rstest]
    #[case(OverlapperType::AIList)]
    #[case(OverlapperType::Bits)]
    fn test_exact_boundary_overlaps(#[case] overlapper_type: OverlapperType) {
        let hits = overlaps(
            overlapper_type,
            vec![plain("chr1", 100, 200)],
            vec![plain("chr1", 200, 300)],
        );

        assert_eq!(hits.len(), 0);
    }

    /// An empty query set yields no hits.
    #[rstest]
    #[case(OverlapperType::AIList)]
    #[case(OverlapperType::Bits)]
    fn test_empty_query(#[case] overlapper_type: OverlapperType) {
        let hits = overlaps(
            overlapper_type,
            vec![plain("chr1", 100, 200)],
            Vec::new(),
        );

        assert_eq!(hits.len(), 0);
    }

    /// A query on a chromosome that was never indexed yields no hits.
    #[rstest]
    #[case(OverlapperType::AIList)]
    #[case(OverlapperType::Bits)]
    fn test_query_nonexistent_chromosome(#[case] overlapper_type: OverlapperType) {
        let hits = overlaps(
            overlapper_type,
            vec![plain("chr1", 100, 200)],
            vec![plain("chr99", 100, 200)],
        );

        assert_eq!(hits.len(), 0);
    }

    /// The `rest` value of an indexed region is returned with its hit.
    #[rstest]
    #[case(OverlapperType::AIList)]
    #[case(OverlapperType::Bits)]
    fn test_with_metadata(#[case] overlapper_type: OverlapperType) {
        let hits = overlaps(
            overlapper_type,
            vec![
                labelled("chr1", 100, 200, "gene_a"),
                labelled("chr1", 300, 400, "gene_b"),
            ],
            vec![plain("chr1", 150, 250)],
        );

        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].1.val.as_deref(), Some("gene_a"));
    }

    /// Two query regions on one chromosome, each overlapping a different
    /// indexed region.
    #[rstest]
    #[case(OverlapperType::AIList)]
    #[case(OverlapperType::Bits)]
    fn test_overlapping_query_regions(#[case] overlapper_type: OverlapperType) {
        let hits = overlaps(
            overlapper_type,
            vec![plain("chr1", 100, 200), plain("chr1", 300, 400)],
            vec![plain("chr1", 150, 250), plain("chr1", 350, 450)],
        );

        assert_eq!(hits.len(), 2);
    }
}