Skip to main content

cyanea_omics/
interval_tree.rs

//! Static augmented interval tree for fast overlap queries.
//!
//! [`IntervalTree`] stores intervals in an implicit balanced BST layout
//! (nodes in a contiguous `Vec`, children of node `i` at `2i+1`/`2i+2`).
//! Build once, then query many times; overlap queries run in O(log n + k).
6
/// A half-open interval `[start, end)` carrying a payload of type `T`.
#[derive(Clone, Debug)]
pub struct Interval<T> {
    /// First coordinate covered by the interval (inclusive).
    pub start: u64,
    /// First coordinate past the interval (exclusive).
    pub end: u64,
    /// Caller-supplied value attached to the interval.
    pub data: T,
}

impl<T> Interval<T> {
    /// Construct an interval from its bounds and payload.
    ///
    /// The arguments are stored exactly as given; no ordering between
    /// `start` and `end` is checked here.
    pub fn new(start: u64, end: u64, data: T) -> Self {
        Interval { start, end, data }
    }
}
24
25/// Internal node in the implicit BST.
26#[derive(Debug, Clone)]
27struct Node<T> {
28    interval: Interval<T>,
29    /// Maximum end coordinate in this subtree.
30    max_end: u64,
31}
32
33/// A static augmented interval tree using an implicit BST layout.
34///
35/// Built once from a set of intervals, then supports efficient overlap queries.
36/// The tree cannot be modified after construction.
37#[derive(Debug, Clone)]
38pub struct IntervalTree<T> {
39    nodes: Vec<Option<Node<T>>>,
40}
41
42impl<T> IntervalTree<T> {
43    /// Build an interval tree from unsorted intervals. O(n log n).
44    pub fn from_unsorted(mut intervals: Vec<Interval<T>>) -> Self {
45        intervals.sort_by_key(|iv| iv.start);
46        Self::from_sorted(intervals)
47    }
48
49    /// Build an interval tree from intervals sorted by start coordinate. O(n).
50    pub fn from_sorted(intervals: Vec<Interval<T>>) -> Self {
51        let n = intervals.len();
52        if n == 0 {
53            return Self { nodes: Vec::new() };
54        }
55
56        // Compute the required array size for the implicit BST
57        let capacity = implicit_tree_size(n);
58        let mut nodes: Vec<Option<Node<T>>> = (0..capacity).map(|_| None).collect();
59
60        // Convert intervals into an indexable vec
61        let mut sorted: Vec<Option<Interval<T>>> = intervals.into_iter().map(Some).collect();
62
63        build_implicit(&mut nodes, &mut sorted, 0, 0, n);
64        augment_max_end(&mut nodes, 0);
65
66        Self { nodes }
67    }
68
69    /// Query all intervals overlapping the range `[start, end)`.
70    ///
71    /// Returns references to all intervals where `interval.start < end && interval.end > start`.
72    pub fn query(&self, start: u64, end: u64) -> Vec<&Interval<T>> {
73        let mut results = Vec::new();
74        if !self.nodes.is_empty() {
75            self.query_recursive(0, start, end, &mut results);
76        }
77        results
78    }
79
80    /// Count intervals overlapping the range `[start, end)` without allocating.
81    pub fn count_overlaps(&self, start: u64, end: u64) -> usize {
82        if self.nodes.is_empty() {
83            return 0;
84        }
85        self.count_recursive(0, start, end)
86    }
87
88    /// Find the nearest interval to a point.
89    ///
90    /// Returns the interval whose midpoint is closest to `point`.
91    /// If multiple intervals are equidistant, returns one arbitrarily.
92    pub fn nearest(&self, point: u64) -> Option<&Interval<T>> {
93        if self.nodes.is_empty() {
94            return None;
95        }
96        let mut best: Option<&Interval<T>> = None;
97        let mut best_dist = u64::MAX;
98        self.nearest_recursive(0, point, &mut best, &mut best_dist);
99        best
100    }
101
102    /// Find the nearest interval that ends at or before `point`.
103    ///
104    /// Returns the interval with the largest `end` that is `<= point`.
105    pub fn preceding(&self, point: u64) -> Option<&Interval<T>> {
106        if self.nodes.is_empty() {
107            return None;
108        }
109        let mut best: Option<&Interval<T>> = None;
110        self.preceding_recursive(0, point, &mut best);
111        best
112    }
113
114    /// Find the nearest interval that starts at or after `point`.
115    ///
116    /// Returns the interval with the smallest `start` that is `>= point`.
117    pub fn following(&self, point: u64) -> Option<&Interval<T>> {
118        if self.nodes.is_empty() {
119            return None;
120        }
121        let mut best: Option<&Interval<T>> = None;
122        self.following_recursive(0, point, &mut best);
123        best
124    }
125
126    /// Number of intervals in the tree.
127    pub fn len(&self) -> usize {
128        self.nodes.iter().filter(|n| n.is_some()).count()
129    }
130
131    /// Whether the tree contains no intervals.
132    pub fn is_empty(&self) -> bool {
133        self.nodes.is_empty() || self.nodes.iter().all(|n| n.is_none())
134    }
135
136    /// Iterate over all intervals in the tree (in-order traversal).
137    pub fn iter(&self) -> impl Iterator<Item = &Interval<T>> {
138        IntervalTreeIter {
139            nodes: &self.nodes,
140            stack: if self.nodes.is_empty() {
141                Vec::new()
142            } else {
143                vec![IterState::Descend(0)]
144            },
145        }
146    }
147
148    fn query_recursive<'a>(
149        &'a self,
150        idx: usize,
151        start: u64,
152        end: u64,
153        results: &mut Vec<&'a Interval<T>>,
154    ) {
155        if idx >= self.nodes.len() {
156            return;
157        }
158        let node = match &self.nodes[idx] {
159            Some(n) => n,
160            None => return,
161        };
162
163        // Prune: if max_end in this subtree <= query start, no overlap possible
164        if node.max_end <= start {
165            return;
166        }
167
168        // Search left subtree
169        let left = 2 * idx + 1;
170        self.query_recursive(left, start, end, results);
171
172        // Check current node
173        if node.interval.start < end && node.interval.end > start {
174            results.push(&node.interval);
175        }
176
177        // Prune right: if node.start >= end, right subtree has only larger starts
178        if node.interval.start < end {
179            let right = 2 * idx + 2;
180            self.query_recursive(right, start, end, results);
181        }
182    }
183
184    fn count_recursive(&self, idx: usize, start: u64, end: u64) -> usize {
185        if idx >= self.nodes.len() {
186            return 0;
187        }
188        let node = match &self.nodes[idx] {
189            Some(n) => n,
190            None => return 0,
191        };
192
193        if node.max_end <= start {
194            return 0;
195        }
196
197        let mut count = 0;
198
199        let left = 2 * idx + 1;
200        count += self.count_recursive(left, start, end);
201
202        if node.interval.start < end && node.interval.end > start {
203            count += 1;
204        }
205
206        if node.interval.start < end {
207            let right = 2 * idx + 2;
208            count += self.count_recursive(right, start, end);
209        }
210
211        count
212    }
213
214    fn nearest_recursive<'a>(
215        &'a self,
216        idx: usize,
217        point: u64,
218        best: &mut Option<&'a Interval<T>>,
219        best_dist: &mut u64,
220    ) {
221        if idx >= self.nodes.len() {
222            return;
223        }
224        let node = match &self.nodes[idx] {
225            Some(n) => n,
226            None => return,
227        };
228
229        // Distance from point to this interval
230        let dist = if point < node.interval.start {
231            node.interval.start - point
232        } else if point >= node.interval.end {
233            point - node.interval.end + 1
234        } else {
235            0 // point is inside the interval
236        };
237
238        if dist < *best_dist {
239            *best_dist = dist;
240            *best = Some(&node.interval);
241        }
242
243        if dist == 0 {
244            return; // Can't do better than overlapping
245        }
246
247        let left = 2 * idx + 1;
248        let right = 2 * idx + 2;
249
250        // Search both subtrees
251        if point < node.interval.start {
252            self.nearest_recursive(left, point, best, best_dist);
253            if node.interval.start - point <= *best_dist {
254                self.nearest_recursive(right, point, best, best_dist);
255            }
256        } else {
257            self.nearest_recursive(right, point, best, best_dist);
258            self.nearest_recursive(left, point, best, best_dist);
259        }
260    }
261
262    fn preceding_recursive<'a>(
263        &'a self,
264        idx: usize,
265        point: u64,
266        best: &mut Option<&'a Interval<T>>,
267    ) {
268        if idx >= self.nodes.len() {
269            return;
270        }
271        let node = match &self.nodes[idx] {
272            Some(n) => n,
273            None => return,
274        };
275
276        if node.interval.end <= point {
277            // This interval ends before point — candidate
278            let is_better = match best {
279                None => true,
280                Some(b) => node.interval.end > b.end
281                    || (node.interval.end == b.end && node.interval.start > b.start),
282            };
283            if is_better {
284                *best = Some(&node.interval);
285            }
286        }
287
288        let left = 2 * idx + 1;
289        let right = 2 * idx + 2;
290
291        // Always check left (may have intervals ending before point)
292        self.preceding_recursive(left, point, best);
293        // Check right subtree too (intervals may end before point but start after current)
294        self.preceding_recursive(right, point, best);
295    }
296
297    fn following_recursive<'a>(
298        &'a self,
299        idx: usize,
300        point: u64,
301        best: &mut Option<&'a Interval<T>>,
302    ) {
303        if idx >= self.nodes.len() {
304            return;
305        }
306        let node = match &self.nodes[idx] {
307            Some(n) => n,
308            None => return,
309        };
310
311        if node.interval.start >= point {
312            // This interval starts at or after point — candidate
313            let is_better = match best {
314                None => true,
315                Some(b) => node.interval.start < b.start,
316            };
317            if is_better {
318                *best = Some(&node.interval);
319            }
320        }
321
322        let left = 2 * idx + 1;
323        let right = 2 * idx + 2;
324
325        // If current node starts after point, left subtree may have closer intervals
326        if node.interval.start >= point {
327            self.following_recursive(left, point, best);
328        }
329        // Always check right subtree
330        self.following_recursive(right, point, best);
331    }
332}
333
// ---------------------------------------------------------------------------
// Implicit BST construction helpers
// ---------------------------------------------------------------------------
337
/// Compute the array size needed for an implicit BST with `n` elements.
///
/// Placing the median at every node yields a tree of height
/// `floor(log2(n)) + 1`, i.e. the bit length of `n`. A complete tree of that
/// height has `2^height - 1` slots, which is the all-ones value covering
/// `n`'s highest set bit. Returns `0` for `n == 0`.
fn implicit_tree_size(n: usize) -> usize {
    if n == 0 {
        return 0;
    }
    // Integer arithmetic avoids floating-point rounding, and shifting the
    // all-ones mask right cannot overflow the way `1 << height` could.
    usize::MAX >> n.leading_zeros()
}
347
348/// Recursively build the implicit BST by placing the median at each node.
349fn build_implicit<T>(
350    nodes: &mut [Option<Node<T>>],
351    sorted: &mut [Option<Interval<T>>],
352    node_idx: usize,
353    lo: usize,
354    hi: usize,
355) {
356    if lo >= hi || node_idx >= nodes.len() {
357        return;
358    }
359
360    let mid = lo + (hi - lo) / 2;
361
362    if let Some(interval) = sorted[mid].take() {
363        let max_end = interval.end;
364        nodes[node_idx] = Some(Node {
365            interval,
366            max_end,
367        });
368
369        let left = 2 * node_idx + 1;
370        let right = 2 * node_idx + 2;
371
372        build_implicit(nodes, sorted, left, lo, mid);
373        build_implicit(nodes, sorted, right, mid + 1, hi);
374    }
375}
376
377/// Post-order traversal to compute augmented max_end values.
378fn augment_max_end<T>(nodes: &mut [Option<Node<T>>], idx: usize) -> u64 {
379    if idx >= nodes.len() {
380        return 0;
381    }
382
383    let node = match &nodes[idx] {
384        Some(n) => n,
385        None => return 0,
386    };
387
388    let own_end = node.interval.end;
389    let left_max = augment_max_end(nodes, 2 * idx + 1);
390    let right_max = augment_max_end(nodes, 2 * idx + 2);
391
392    let max_end = own_end.max(left_max).max(right_max);
393
394    if let Some(ref mut n) = nodes[idx] {
395        n.max_end = max_end;
396    }
397
398    max_end
399}
400
// ---------------------------------------------------------------------------
// Iterator
// ---------------------------------------------------------------------------
404
405enum IterState {
406    Descend(usize),
407    Visit(usize),
408}
409
410struct IntervalTreeIter<'a, T> {
411    nodes: &'a [Option<Node<T>>],
412    stack: Vec<IterState>,
413}
414
415impl<'a, T> Iterator for IntervalTreeIter<'a, T> {
416    type Item = &'a Interval<T>;
417
418    fn next(&mut self) -> Option<Self::Item> {
419        loop {
420            let state = self.stack.pop()?;
421            match state {
422                IterState::Descend(idx) => {
423                    if idx >= self.nodes.len() {
424                        continue;
425                    }
426                    if self.nodes[idx].is_none() {
427                        continue;
428                    }
429                    // Push right, then visit, then left (so left is processed first)
430                    self.stack.push(IterState::Descend(2 * idx + 2));
431                    self.stack.push(IterState::Visit(idx));
432                    self.stack.push(IterState::Descend(2 * idx + 1));
433                }
434                IterState::Visit(idx) => {
435                    if let Some(node) = &self.nodes[idx] {
436                        return Some(&node.interval);
437                    }
438                }
439            }
440        }
441    }
442}
443
#[cfg(test)]
mod tests {
    use super::*;

    /// Payload-free interval `[start, end)`.
    fn iv(start: u64, end: u64) -> Interval<()> {
        Interval::new(start, end, ())
    }

    /// Interval `[start, end)` tagged with `data`.
    fn iv_data(start: u64, end: u64, data: usize) -> Interval<usize> {
        Interval::new(start, end, data)
    }

    /// Check `tree.query(start, end).len()` against each `(start, end, hits)` row.
    fn assert_hits<T>(tree: &IntervalTree<T>, cases: &[(u64, u64, usize)]) {
        for &(start, end, hits) in cases {
            assert_eq!(tree.query(start, end).len(), hits);
        }
    }

    /// Three disjoint intervals shared by the point-lookup tests.
    fn three_disjoint() -> IntervalTree<()> {
        IntervalTree::from_unsorted(vec![iv(10, 20), iv(30, 40), iv(60, 70)])
    }

    #[test]
    fn empty_tree() {
        let tree: IntervalTree<()> = IntervalTree::from_unsorted(Vec::new());
        assert!(tree.is_empty());
        assert_eq!(tree.len(), 0);
        assert_eq!(tree.query(0, 100).len(), 0);
        assert_eq!(tree.count_overlaps(0, 100), 0);
        assert!(tree.nearest(50).is_none());
        assert!(tree.preceding(50).is_none());
        assert!(tree.following(50).is_none());
        assert_eq!(tree.iter().count(), 0);
    }

    #[test]
    fn single_interval() {
        let tree = IntervalTree::from_unsorted(vec![iv(10, 20)]);
        assert_eq!(tree.len(), 1);
        assert!(!tree.is_empty());

        assert_hits(
            &tree,
            &[
                (5, 15, 1),
                (15, 25, 1),
                (10, 20, 1),
                (0, 10, 0),  // abutting on the left
                (20, 30, 0), // abutting on the right
                (25, 30, 0),
            ],
        );
    }

    #[test]
    fn many_intervals() {
        let tree = IntervalTree::from_unsorted(vec![
            iv(0, 10),
            iv(5, 15),
            iv(20, 30),
            iv(25, 35),
            iv(50, 60),
        ]);
        assert_eq!(tree.len(), 5);

        assert_hits(
            &tree,
            &[
                (8, 12, 2),  // first two
                (22, 28, 2), // middle two
                (40, 45, 0), // gap
                (0, 35, 4),  // whole first cluster
            ],
        );
    }

    #[test]
    fn nested_intervals() {
        let tree = IntervalTree::from_unsorted(vec![
            iv(0, 100),
            iv(10, 90),
            iv(20, 80),
            iv(30, 70),
            iv(40, 60),
        ]);

        assert_hits(
            &tree,
            &[
                (45, 55, 5),  // centre hits every level of nesting
                (0, 1, 1),    // left edge: outermost only
                (95, 100, 1), // right edge: outermost only
            ],
        );
    }

    #[test]
    fn adjacent_intervals() {
        let tree = IntervalTree::from_unsorted(vec![iv(0, 10), iv(10, 20), iv(20, 30)]);

        // Half-open semantics: abutting intervals do not overlap.
        assert_hits(&tree, &[(10, 20, 1), (9, 11, 2)]);
    }

    #[test]
    fn all_same_start() {
        let tree = IntervalTree::from_unsorted(vec![
            iv(10, 20),
            iv(10, 30),
            iv(10, 40),
            iv(10, 50),
        ]);

        assert_hits(
            &tree,
            &[(10, 11, 4), (25, 26, 3), (35, 36, 2), (45, 46, 1)],
        );
    }

    #[test]
    fn count_overlaps() {
        let tree = IntervalTree::from_unsorted(vec![iv(0, 10), iv(5, 15), iv(20, 30)]);
        let cases: [(u64, u64, usize); 3] = [(8, 12, 2), (25, 35, 1), (16, 19, 0)];
        for &(start, end, expected) in cases.iter() {
            assert_eq!(tree.count_overlaps(start, end), expected);
        }
    }

    #[test]
    fn nearest_basic() {
        let tree = three_disjoint();
        let cases: [(u64, u64); 4] = [
            (15, 10),  // inside [10, 20)
            (28, 30),  // between intervals, closer to [30, 40)
            (0, 10),   // before everything
            (100, 60), // after everything
        ];
        for &(point, expected_start) in cases.iter() {
            let found = tree.nearest(point).unwrap();
            assert_eq!(found.start, expected_start);
        }
    }

    #[test]
    fn preceding_basic() {
        let tree = three_disjoint();

        // Before the first interval nothing precedes.
        assert!(tree.preceding(5).is_none());

        let cases: [(u64, u64); 3] = [(25, 10), (50, 30), (100, 60)];
        for &(point, expected_start) in cases.iter() {
            let found = tree.preceding(point).unwrap();
            assert_eq!(found.start, expected_start);
        }
    }

    #[test]
    fn following_basic() {
        let tree = three_disjoint();

        let cases: [(u64, u64); 3] = [
            (0, 10),  // before the first interval
            (25, 30), // between intervals
            (30, 30), // exactly at an interval start
        ];
        for &(point, expected_start) in cases.iter() {
            let found = tree.following(point).unwrap();
            assert_eq!(found.start, expected_start);
        }

        // After the last interval nothing follows.
        assert!(tree.following(75).is_none());
    }

    #[test]
    fn preceding_at_boundary() {
        let tree = IntervalTree::from_unsorted(vec![iv(10, 20)]);

        // end == point counts as preceding (end <= point).
        let found = tree.preceding(20).unwrap();
        assert_eq!(found.start, 10);

        // end > point does not.
        assert!(tree.preceding(15).is_none());
    }

    #[test]
    fn following_at_boundary() {
        let tree = IntervalTree::from_unsorted(vec![iv(10, 20)]);

        // start == point counts as following.
        let found = tree.following(10).unwrap();
        assert_eq!(found.start, 10);

        // start < point does not.
        assert!(tree.following(15).is_none());
    }

    #[test]
    fn iter_in_order() {
        let tree = IntervalTree::from_unsorted(vec![
            iv(30, 40),
            iv(10, 20),
            iv(50, 60),
            iv(0, 5),
        ]);

        // In-order traversal yields intervals sorted by start.
        let mut starts: Vec<u64> = Vec::new();
        for interval in tree.iter() {
            starts.push(interval.start);
        }
        assert_eq!(starts, vec![0, 10, 30, 50]);
    }

    #[test]
    fn from_sorted() {
        let tree = IntervalTree::from_sorted(vec![iv(0, 10), iv(10, 20), iv(20, 30)]);
        assert_eq!(tree.len(), 3);
        assert_hits(
            &tree,
            &[
                (5, 25, 3), // all three overlap [5, 25)
                (5, 15, 2), // [0,10) and [10,20)
            ],
        );
    }

    #[test]
    fn data_preserved() {
        let tree = IntervalTree::from_unsorted(vec![iv_data(10, 20, 42), iv_data(30, 40, 99)]);

        let hits = tree.query(15, 35);
        assert_eq!(hits.len(), 2);

        let mut payloads: Vec<usize> = Vec::with_capacity(hits.len());
        for hit in &hits {
            payloads.push(hit.data);
        }
        payloads.sort();
        assert_eq!(payloads, vec![42, 99]);
    }

    #[test]
    fn large_tree() {
        let mut intervals: Vec<Interval<usize>> = Vec::with_capacity(1000);
        for i in 0..1000u64 {
            intervals.push(iv_data(i * 10, i * 10 + 5, i as usize));
        }
        let tree = IntervalTree::from_unsorted(intervals);
        assert_eq!(tree.len(), 1000);

        // Narrow range: exactly one hit.
        let narrow = tree.query(500, 510);
        assert_eq!(narrow.len(), 1);
        assert_eq!(narrow[0].data, 50);

        // Range covering everything.
        let wide = tree.query(0, 10000);
        assert_eq!(wide.len(), 1000);
    }

    #[test]
    fn query_matches_linear_scan() {
        // Property test: the tree must agree with a brute-force scan.
        let intervals = vec![
            iv(5, 15),
            iv(10, 25),
            iv(20, 35),
            iv(30, 45),
            iv(40, 55),
            iv(0, 100),
            iv(50, 60),
            iv(70, 80),
        ];

        let tree = IntervalTree::from_unsorted(intervals.clone());

        for start in (0..100).step_by(7) {
            for end in (start + 1..110).step_by(11) {
                let tree_count = tree.count_overlaps(start, end);
                let mut linear_count = 0;
                for candidate in &intervals {
                    if candidate.start < end && candidate.end > start {
                        linear_count += 1;
                    }
                }
                assert_eq!(
                    tree_count, linear_count,
                    "mismatch for query [{}, {}): tree={}, linear={}",
                    start, end, tree_count, linear_count
                );
            }
        }
    }

    #[test]
    fn two_intervals() {
        let tree = IntervalTree::from_unsorted(vec![iv(0, 10), iv(20, 30)]);
        assert_eq!(tree.len(), 2);
        assert_hits(
            &tree,
            &[(5, 25, 2), (5, 15, 1), (25, 35, 1), (12, 18, 0)],
        );
    }
}