/// An interval with an attached payload.
///
/// Queries in this module treat an interval as half-open, `[start, end)`: it
/// covers `start` but not `end`, so `[0, 10)` and `[10, 20)` do not overlap.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Interval<T> {
    /// Inclusive lower bound.
    pub start: u64,
    /// Exclusive upper bound.
    pub end: u64,
    /// Caller-supplied payload carried alongside the bounds.
    pub data: T,
}
17
18impl<T> Interval<T> {
19 pub fn new(start: u64, end: u64, data: T) -> Self {
21 Self { start, end, data }
22 }
23}
24
25#[derive(Debug, Clone)]
27struct Node<T> {
28 interval: Interval<T>,
29 max_end: u64,
31}
32
33#[derive(Debug, Clone)]
38pub struct IntervalTree<T> {
39 nodes: Vec<Option<Node<T>>>,
40}
41
42impl<T> IntervalTree<T> {
43 pub fn from_unsorted(mut intervals: Vec<Interval<T>>) -> Self {
45 intervals.sort_by_key(|iv| iv.start);
46 Self::from_sorted(intervals)
47 }
48
49 pub fn from_sorted(intervals: Vec<Interval<T>>) -> Self {
51 let n = intervals.len();
52 if n == 0 {
53 return Self { nodes: Vec::new() };
54 }
55
56 let capacity = implicit_tree_size(n);
58 let mut nodes: Vec<Option<Node<T>>> = (0..capacity).map(|_| None).collect();
59
60 let mut sorted: Vec<Option<Interval<T>>> = intervals.into_iter().map(Some).collect();
62
63 build_implicit(&mut nodes, &mut sorted, 0, 0, n);
64 augment_max_end(&mut nodes, 0);
65
66 Self { nodes }
67 }
68
69 pub fn query(&self, start: u64, end: u64) -> Vec<&Interval<T>> {
73 let mut results = Vec::new();
74 if !self.nodes.is_empty() {
75 self.query_recursive(0, start, end, &mut results);
76 }
77 results
78 }
79
80 pub fn count_overlaps(&self, start: u64, end: u64) -> usize {
82 if self.nodes.is_empty() {
83 return 0;
84 }
85 self.count_recursive(0, start, end)
86 }
87
88 pub fn nearest(&self, point: u64) -> Option<&Interval<T>> {
93 if self.nodes.is_empty() {
94 return None;
95 }
96 let mut best: Option<&Interval<T>> = None;
97 let mut best_dist = u64::MAX;
98 self.nearest_recursive(0, point, &mut best, &mut best_dist);
99 best
100 }
101
102 pub fn preceding(&self, point: u64) -> Option<&Interval<T>> {
106 if self.nodes.is_empty() {
107 return None;
108 }
109 let mut best: Option<&Interval<T>> = None;
110 self.preceding_recursive(0, point, &mut best);
111 best
112 }
113
114 pub fn following(&self, point: u64) -> Option<&Interval<T>> {
118 if self.nodes.is_empty() {
119 return None;
120 }
121 let mut best: Option<&Interval<T>> = None;
122 self.following_recursive(0, point, &mut best);
123 best
124 }
125
126 pub fn len(&self) -> usize {
128 self.nodes.iter().filter(|n| n.is_some()).count()
129 }
130
131 pub fn is_empty(&self) -> bool {
133 self.nodes.is_empty() || self.nodes.iter().all(|n| n.is_none())
134 }
135
136 pub fn iter(&self) -> impl Iterator<Item = &Interval<T>> {
138 IntervalTreeIter {
139 nodes: &self.nodes,
140 stack: if self.nodes.is_empty() {
141 Vec::new()
142 } else {
143 vec![IterState::Descend(0)]
144 },
145 }
146 }
147
148 fn query_recursive<'a>(
149 &'a self,
150 idx: usize,
151 start: u64,
152 end: u64,
153 results: &mut Vec<&'a Interval<T>>,
154 ) {
155 if idx >= self.nodes.len() {
156 return;
157 }
158 let node = match &self.nodes[idx] {
159 Some(n) => n,
160 None => return,
161 };
162
163 if node.max_end <= start {
165 return;
166 }
167
168 let left = 2 * idx + 1;
170 self.query_recursive(left, start, end, results);
171
172 if node.interval.start < end && node.interval.end > start {
174 results.push(&node.interval);
175 }
176
177 if node.interval.start < end {
179 let right = 2 * idx + 2;
180 self.query_recursive(right, start, end, results);
181 }
182 }
183
184 fn count_recursive(&self, idx: usize, start: u64, end: u64) -> usize {
185 if idx >= self.nodes.len() {
186 return 0;
187 }
188 let node = match &self.nodes[idx] {
189 Some(n) => n,
190 None => return 0,
191 };
192
193 if node.max_end <= start {
194 return 0;
195 }
196
197 let mut count = 0;
198
199 let left = 2 * idx + 1;
200 count += self.count_recursive(left, start, end);
201
202 if node.interval.start < end && node.interval.end > start {
203 count += 1;
204 }
205
206 if node.interval.start < end {
207 let right = 2 * idx + 2;
208 count += self.count_recursive(right, start, end);
209 }
210
211 count
212 }
213
214 fn nearest_recursive<'a>(
215 &'a self,
216 idx: usize,
217 point: u64,
218 best: &mut Option<&'a Interval<T>>,
219 best_dist: &mut u64,
220 ) {
221 if idx >= self.nodes.len() {
222 return;
223 }
224 let node = match &self.nodes[idx] {
225 Some(n) => n,
226 None => return,
227 };
228
229 let dist = if point < node.interval.start {
231 node.interval.start - point
232 } else if point >= node.interval.end {
233 point - node.interval.end + 1
234 } else {
235 0 };
237
238 if dist < *best_dist {
239 *best_dist = dist;
240 *best = Some(&node.interval);
241 }
242
243 if dist == 0 {
244 return; }
246
247 let left = 2 * idx + 1;
248 let right = 2 * idx + 2;
249
250 if point < node.interval.start {
252 self.nearest_recursive(left, point, best, best_dist);
253 if node.interval.start - point <= *best_dist {
254 self.nearest_recursive(right, point, best, best_dist);
255 }
256 } else {
257 self.nearest_recursive(right, point, best, best_dist);
258 self.nearest_recursive(left, point, best, best_dist);
259 }
260 }
261
262 fn preceding_recursive<'a>(
263 &'a self,
264 idx: usize,
265 point: u64,
266 best: &mut Option<&'a Interval<T>>,
267 ) {
268 if idx >= self.nodes.len() {
269 return;
270 }
271 let node = match &self.nodes[idx] {
272 Some(n) => n,
273 None => return,
274 };
275
276 if node.interval.end <= point {
277 let is_better = match best {
279 None => true,
280 Some(b) => node.interval.end > b.end
281 || (node.interval.end == b.end && node.interval.start > b.start),
282 };
283 if is_better {
284 *best = Some(&node.interval);
285 }
286 }
287
288 let left = 2 * idx + 1;
289 let right = 2 * idx + 2;
290
291 self.preceding_recursive(left, point, best);
293 self.preceding_recursive(right, point, best);
295 }
296
297 fn following_recursive<'a>(
298 &'a self,
299 idx: usize,
300 point: u64,
301 best: &mut Option<&'a Interval<T>>,
302 ) {
303 if idx >= self.nodes.len() {
304 return;
305 }
306 let node = match &self.nodes[idx] {
307 Some(n) => n,
308 None => return,
309 };
310
311 if node.interval.start >= point {
312 let is_better = match best {
314 None => true,
315 Some(b) => node.interval.start < b.start,
316 };
317 if is_better {
318 *best = Some(&node.interval);
319 }
320 }
321
322 let left = 2 * idx + 1;
323 let right = 2 * idx + 2;
324
325 if node.interval.start >= point {
327 self.following_recursive(left, point, best);
328 }
329 self.following_recursive(right, point, best);
331 }
332}
333
/// Number of array slots needed to hold a midpoint-split tree over `n` items.
///
/// Splitting a range of `n` items at `lo + (hi - lo) / 2` leaves `n / 2` items
/// on the left and at most that many on the right. The tree is therefore
/// `floor(log2 n) + 1` levels deep and fits in `2^height - 1` slots. Integer
/// arithmetic gives the exact size (the previous `ceil(log2 n) + 1` float
/// formula over-allocated up to 2x) and cannot overflow, even for `n` near
/// `usize::MAX`.
fn implicit_tree_size(n: usize) -> usize {
    if n == 0 {
        return 0;
    }
    // Bit length of `n`, i.e. floor(log2 n) + 1; always within 1..=usize::BITS.
    let height = usize::BITS - n.leading_zeros();
    // 2^height - 1, written as a right shift so height == usize::BITS cannot overflow.
    usize::MAX >> (usize::BITS - height)
}
347
348fn build_implicit<T>(
350 nodes: &mut [Option<Node<T>>],
351 sorted: &mut [Option<Interval<T>>],
352 node_idx: usize,
353 lo: usize,
354 hi: usize,
355) {
356 if lo >= hi || node_idx >= nodes.len() {
357 return;
358 }
359
360 let mid = lo + (hi - lo) / 2;
361
362 if let Some(interval) = sorted[mid].take() {
363 let max_end = interval.end;
364 nodes[node_idx] = Some(Node {
365 interval,
366 max_end,
367 });
368
369 let left = 2 * node_idx + 1;
370 let right = 2 * node_idx + 2;
371
372 build_implicit(nodes, sorted, left, lo, mid);
373 build_implicit(nodes, sorted, right, mid + 1, hi);
374 }
375}
376
377fn augment_max_end<T>(nodes: &mut [Option<Node<T>>], idx: usize) -> u64 {
379 if idx >= nodes.len() {
380 return 0;
381 }
382
383 let node = match &nodes[idx] {
384 Some(n) => n,
385 None => return 0,
386 };
387
388 let own_end = node.interval.end;
389 let left_max = augment_max_end(nodes, 2 * idx + 1);
390 let right_max = augment_max_end(nodes, 2 * idx + 2);
391
392 let max_end = own_end.max(left_max).max(right_max);
393
394 if let Some(ref mut n) = nodes[idx] {
395 n.max_end = max_end;
396 }
397
398 max_end
399}
400
/// Work item on the explicit stack of the in-order tree walk.
enum IterState {
    /// Expand the subtree rooted at this slot index: left subtree, node, right subtree.
    Descend(usize),
    /// Yield the interval stored at this slot index.
    Visit(usize),
}
409
410struct IntervalTreeIter<'a, T> {
411 nodes: &'a [Option<Node<T>>],
412 stack: Vec<IterState>,
413}
414
415impl<'a, T> Iterator for IntervalTreeIter<'a, T> {
416 type Item = &'a Interval<T>;
417
418 fn next(&mut self) -> Option<Self::Item> {
419 loop {
420 let state = self.stack.pop()?;
421 match state {
422 IterState::Descend(idx) => {
423 if idx >= self.nodes.len() {
424 continue;
425 }
426 if self.nodes[idx].is_none() {
427 continue;
428 }
429 self.stack.push(IterState::Descend(2 * idx + 2));
431 self.stack.push(IterState::Visit(idx));
432 self.stack.push(IterState::Descend(2 * idx + 1));
433 }
434 IterState::Visit(idx) => {
435 if let Some(node) = &self.nodes[idx] {
436 return Some(&node.interval);
437 }
438 }
439 }
440 }
441 }
442}
443
#[cfg(test)]
mod tests {
    use super::*;

    fn iv(start: u64, end: u64) -> Interval<()> {
        Interval::new(start, end, ())
    }

    fn iv_data(start: u64, end: u64, data: usize) -> Interval<usize> {
        Interval::new(start, end, data)
    }

    /// Overlapping, nested and disjoint intervals for the linear-scan cross-checks.
    fn overlap_fixture() -> Vec<Interval<()>> {
        vec![
            iv(5, 15),
            iv(10, 25),
            iv(20, 35),
            iv(30, 45),
            iv(40, 55),
            iv(0, 100),
            iv(50, 60),
            iv(70, 80),
        ]
    }

    /// Sparse intervals with gaps, an empty interval and a duplicated start,
    /// for the nearest / preceding / following brute-force cross-checks.
    fn sparse_fixture() -> Vec<Interval<()>> {
        vec![
            iv(3, 8),
            iv(10, 12),
            iv(15, 40),
            iv(18, 20),
            iv(45, 45),
            iv(60, 90),
            iv(61, 62),
            iv(100, 101),
            iv(100, 130),
        ]
    }

    /// Reference distance, mirroring the definition used by `nearest`: zero
    /// inside `[start, end)`, otherwise the gap to `start` or to `end - 1`.
    fn distance<T>(interval: &Interval<T>, point: u64) -> u64 {
        if point < interval.start {
            interval.start - point
        } else if point >= interval.end {
            point - interval.end + 1
        } else {
            0
        }
    }

    #[test]
    fn empty_tree() {
        let tree: IntervalTree<()> = IntervalTree::from_unsorted(vec![]);
        assert!(tree.is_empty());
        assert_eq!(tree.len(), 0);
        assert_eq!(tree.query(0, 100).len(), 0);
        assert_eq!(tree.count_overlaps(0, 100), 0);
        assert!(tree.nearest(50).is_none());
        assert!(tree.preceding(50).is_none());
        assert!(tree.following(50).is_none());
        assert_eq!(tree.iter().count(), 0);
    }

    #[test]
    fn single_interval() {
        let tree = IntervalTree::from_unsorted(vec![iv(10, 20)]);
        assert_eq!(tree.len(), 1);
        assert!(!tree.is_empty());

        assert_eq!(tree.query(5, 15).len(), 1);
        assert_eq!(tree.query(15, 25).len(), 1);
        assert_eq!(tree.query(10, 20).len(), 1);
        assert_eq!(tree.query(0, 10).len(), 0);
        assert_eq!(tree.query(20, 30).len(), 0);
        assert_eq!(tree.query(25, 30).len(), 0);
    }

    #[test]
    fn many_intervals() {
        let tree = IntervalTree::from_unsorted(vec![
            iv(0, 10),
            iv(5, 15),
            iv(20, 30),
            iv(25, 35),
            iv(50, 60),
        ]);
        assert_eq!(tree.len(), 5);

        let hits = tree.query(8, 12);
        assert_eq!(hits.len(), 2);

        let hits = tree.query(22, 28);
        assert_eq!(hits.len(), 2);

        let hits = tree.query(40, 45);
        assert_eq!(hits.len(), 0);

        let hits = tree.query(0, 35);
        assert_eq!(hits.len(), 4);
    }

    #[test]
    fn nested_intervals() {
        let tree = IntervalTree::from_unsorted(vec![
            iv(0, 100),
            iv(10, 90),
            iv(20, 80),
            iv(30, 70),
            iv(40, 60),
        ]);

        assert_eq!(tree.query(45, 55).len(), 5);
        assert_eq!(tree.query(0, 1).len(), 1);
        assert_eq!(tree.query(95, 100).len(), 1);
    }

    #[test]
    fn adjacent_intervals() {
        let tree = IntervalTree::from_unsorted(vec![iv(0, 10), iv(10, 20), iv(20, 30)]);

        assert_eq!(tree.query(10, 20).len(), 1);
        assert_eq!(tree.query(9, 11).len(), 2);
    }

    #[test]
    fn all_same_start() {
        let tree = IntervalTree::from_unsorted(vec![
            iv(10, 20),
            iv(10, 30),
            iv(10, 40),
            iv(10, 50),
        ]);

        assert_eq!(tree.query(10, 11).len(), 4);
        assert_eq!(tree.query(25, 26).len(), 3);
        assert_eq!(tree.query(35, 36).len(), 2);
        assert_eq!(tree.query(45, 46).len(), 1);
    }

    #[test]
    fn count_overlaps() {
        let tree = IntervalTree::from_unsorted(vec![iv(0, 10), iv(5, 15), iv(20, 30)]);
        assert_eq!(tree.count_overlaps(8, 12), 2);
        assert_eq!(tree.count_overlaps(25, 35), 1);
        assert_eq!(tree.count_overlaps(16, 19), 0);
    }

    #[test]
    fn nearest_basic() {
        let tree = IntervalTree::from_unsorted(vec![iv(10, 20), iv(30, 40), iv(60, 70)]);

        let n = tree.nearest(15).unwrap();
        assert_eq!(n.start, 10);

        let n = tree.nearest(28).unwrap();
        assert_eq!(n.start, 30);

        let n = tree.nearest(0).unwrap();
        assert_eq!(n.start, 10);

        let n = tree.nearest(100).unwrap();
        assert_eq!(n.start, 60);
    }

    #[test]
    fn preceding_basic() {
        let tree = IntervalTree::from_unsorted(vec![iv(10, 20), iv(30, 40), iv(60, 70)]);

        assert!(tree.preceding(5).is_none());

        let p = tree.preceding(25).unwrap();
        assert_eq!(p.start, 10);

        let p = tree.preceding(50).unwrap();
        assert_eq!(p.start, 30);

        let p = tree.preceding(100).unwrap();
        assert_eq!(p.start, 60);
    }

    #[test]
    fn following_basic() {
        let tree = IntervalTree::from_unsorted(vec![iv(10, 20), iv(30, 40), iv(60, 70)]);

        let f = tree.following(0).unwrap();
        assert_eq!(f.start, 10);

        let f = tree.following(25).unwrap();
        assert_eq!(f.start, 30);

        let f = tree.following(30).unwrap();
        assert_eq!(f.start, 30);

        assert!(tree.following(75).is_none());
    }

    #[test]
    fn preceding_at_boundary() {
        let tree = IntervalTree::from_unsorted(vec![iv(10, 20)]);

        let p = tree.preceding(20).unwrap();
        assert_eq!(p.start, 10);

        assert!(tree.preceding(15).is_none());
    }

    #[test]
    fn following_at_boundary() {
        let tree = IntervalTree::from_unsorted(vec![iv(10, 20)]);

        let f = tree.following(10).unwrap();
        assert_eq!(f.start, 10);

        assert!(tree.following(15).is_none());
    }

    #[test]
    fn iter_in_order() {
        let tree = IntervalTree::from_unsorted(vec![
            iv(30, 40),
            iv(10, 20),
            iv(50, 60),
            iv(0, 5),
        ]);

        let starts: Vec<u64> = tree.iter().map(|i| i.start).collect();
        assert_eq!(starts, vec![0, 10, 30, 50]);
    }

    #[test]
    fn from_sorted() {
        let sorted = vec![iv(0, 10), iv(10, 20), iv(20, 30)];
        let tree = IntervalTree::from_sorted(sorted);
        assert_eq!(tree.len(), 3);
        assert_eq!(tree.query(5, 25).len(), 3);
        assert_eq!(tree.query(5, 15).len(), 2);
    }

    #[test]
    fn data_preserved() {
        let tree = IntervalTree::from_unsorted(vec![iv_data(10, 20, 42), iv_data(30, 40, 99)]);

        let hits = tree.query(15, 35);
        assert_eq!(hits.len(), 2);
        let mut data: Vec<usize> = hits.iter().map(|h| h.data).collect();
        data.sort();
        assert_eq!(data, vec![42, 99]);
    }

    #[test]
    fn large_tree() {
        let intervals: Vec<Interval<usize>> = (0..1000)
            .map(|i| iv_data(i * 10, i * 10 + 5, i as usize))
            .collect();
        let tree = IntervalTree::from_unsorted(intervals);
        assert_eq!(tree.len(), 1000);

        let hits = tree.query(500, 510);
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].data, 50);

        let hits = tree.query(0, 10000);
        assert_eq!(hits.len(), 1000);
    }

    #[test]
    fn query_matches_linear_scan() {
        let intervals = overlap_fixture();
        let tree = IntervalTree::from_unsorted(intervals.clone());

        for start in (0..100).step_by(7) {
            for end in (start + 1..110).step_by(11) {
                let tree_count = tree.count_overlaps(start, end);
                let linear_count = intervals
                    .iter()
                    .filter(|i| i.start < end && i.end > start)
                    .count();
                assert_eq!(
                    tree_count, linear_count,
                    "mismatch for query [{}, {}): tree={}, linear={}",
                    start, end, tree_count, linear_count
                );
            }
        }
    }

    #[test]
    fn query_results_match_linear_scan() {
        let intervals = overlap_fixture();
        let tree = IntervalTree::from_unsorted(intervals.clone());

        for start in (0..110).step_by(3) {
            for end in (start + 1..115).step_by(4) {
                let mut got: Vec<(u64, u64)> =
                    tree.query(start, end).iter().map(|i| (i.start, i.end)).collect();
                let mut want: Vec<(u64, u64)> = intervals
                    .iter()
                    .filter(|i| i.start < end && i.end > start)
                    .map(|i| (i.start, i.end))
                    .collect();
                got.sort();
                want.sort();
                assert_eq!(got, want, "query [{}, {})", start, end);
            }
        }
    }

    #[test]
    fn nearest_matches_brute_force() {
        let intervals = sparse_fixture();
        let tree = IntervalTree::from_unsorted(intervals.clone());

        for point in 0..140 {
            let want = intervals.iter().map(|i| distance(i, point)).min().unwrap();
            let got = tree.nearest(point).expect("tree is non-empty");
            assert_eq!(distance(got, point), want, "nearest({})", point);
        }
    }

    #[test]
    fn preceding_matches_brute_force() {
        let intervals = sparse_fixture();
        let tree = IntervalTree::from_unsorted(intervals.clone());

        for point in 0..140 {
            let want = intervals
                .iter()
                .filter(|i| i.end <= point)
                .map(|i| (i.end, i.start))
                .max();
            let got = tree.preceding(point).map(|i| (i.end, i.start));
            assert_eq!(got, want, "preceding({})", point);
        }
    }

    #[test]
    fn following_matches_brute_force() {
        let intervals = sparse_fixture();
        let tree = IntervalTree::from_unsorted(intervals.clone());

        for point in 0..140 {
            let want = intervals.iter().filter(|i| i.start >= point).map(|i| i.start).min();
            let got = tree.following(point).map(|i| i.start);
            assert_eq!(got, want, "following({})", point);
        }
    }

    #[test]
    fn iter_yields_every_interval_sorted() {
        let intervals = overlap_fixture();
        let tree = IntervalTree::from_unsorted(intervals.clone());

        let mut want: Vec<(u64, u64)> = intervals.iter().map(|i| (i.start, i.end)).collect();
        want.sort();
        let got: Vec<(u64, u64)> = tree.iter().map(|i| (i.start, i.end)).collect();
        assert_eq!(got, want);
        assert_eq!(tree.len(), intervals.len());
    }

    #[test]
    fn two_intervals() {
        let tree = IntervalTree::from_unsorted(vec![iv(0, 10), iv(20, 30)]);
        assert_eq!(tree.len(), 2);
        assert_eq!(tree.query(5, 25).len(), 2);
        assert_eq!(tree.query(5, 15).len(), 1);
        assert_eq!(tree.query(25, 35).len(), 1);
        assert_eq!(tree.query(12, 18).len(), 0);
    }
}