use alloc::boxed::Box;
use core::{fmt::Debug, ops::Range};
use crate::{
entry::Entry,
interval::Interval,
iter::{
ContainsPruner, DuringPruner, FinishesPruner, MeetsPruner, MetByPruner, OverlapsPruner,
OwnedIter, PrecededByPruner, PrecedesPruner, PruningIter, RefIter, StartsPruner,
},
node::{remove_recurse, Node, RemoveResult},
};
/// An interval tree mapping half-open [`Range`] keys to values.
///
/// Keys are matched by exact range equality: a lookup for `42..46` does not
/// find an entry stored under `42..45`. The `iter_*` query methods answer
/// relational questions (overlaps, precedes, contains, ...) against a query
/// range.
///
/// The wrapped option is `None` for an empty tree; otherwise it owns the root
/// node.
#[derive(Clone, Debug)]
pub struct IntervalTree<R, V>(Option<Box<Node<R, V>>>);

/// Creates an empty tree.
///
/// Written by hand rather than derived so that `R` and `V` are not required to
/// implement `Default`.
impl<R, V> Default for IntervalTree<R, V> {
    fn default() -> Self {
        Self(None)
    }
}
impl<R, V> IntervalTree<R, V>
where
R: Ord + Clone + Debug,
{
pub fn insert(&mut self, range: Range<R>, value: V) -> Option<V> {
let interval = Interval::from(range);
match self.0 {
Some(ref mut v) => v.insert(interval, value),
None => {
self.0 = Some(Box::new(Node::new(interval, value)));
None
}
}
}
pub fn get(&self, range: &Range<R>) -> Option<&V> {
self.0.as_ref().and_then(|v| v.get(range))
}
pub fn get_mut(&mut self, range: &Range<R>) -> Option<&mut V> {
self.0.as_mut().and_then(|v| v.get_mut(range))
}
pub fn contains_key(&self, range: &Range<R>) -> bool {
self.get(range).is_some()
}
pub fn remove(&mut self, range: &Range<R>) -> Option<V> {
match remove_recurse(&mut self.0, range)? {
RemoveResult::Removed(v) => Some(v),
RemoveResult::ParentUnlink => unreachable!(),
}
}
pub fn entry(&mut self, range: Range<R>) -> Entry<'_, R, V> {
Entry::new(range, self)
}
pub fn iter(&self) -> impl Iterator<Item = (&Range<R>, &V)> {
self.0
.iter()
.flat_map(|v| RefIter::new(v))
.map(|v| (v.interval().as_range(), v.value()))
}
pub fn iter_overlaps<'a: 'b, 'b>(
&'a self,
range: &'b Range<R>,
) -> impl Iterator<Item = (&'a Range<R>, &'a V)> + 'b {
self.0
.iter()
.flat_map(|v| PruningIter::new(v, range, OverlapsPruner))
.map(|v| (v.interval().as_range(), v.value()))
}
pub fn iter_precedes<'a: 'b, 'b>(
&'a self,
range: &'b Range<R>,
) -> impl Iterator<Item = (&'a Range<R>, &'a V)> + 'b {
self.0
.iter()
.flat_map(|v| PruningIter::new(v, range, PrecedesPruner))
.map(|v| (v.interval().as_range(), v.value()))
}
pub fn iter_preceded_by<'a: 'b, 'b>(
&'a self,
range: &'b Range<R>,
) -> impl Iterator<Item = (&'a Range<R>, &'a V)> + 'b {
self.0
.iter()
.flat_map(|v| PruningIter::new(v, range, PrecededByPruner))
.map(|v| (v.interval().as_range(), v.value()))
}
pub fn iter_meets<'a: 'b, 'b>(
&'a self,
range: &'b Range<R>,
) -> impl Iterator<Item = (&'a Range<R>, &'a V)> + 'b {
self.0
.iter()
.flat_map(|v| PruningIter::new(v, range, MeetsPruner))
.map(|v| (v.interval().as_range(), v.value()))
}
pub fn iter_met_by<'a: 'b, 'b>(
&'a self,
range: &'b Range<R>,
) -> impl Iterator<Item = (&'a Range<R>, &'a V)> + 'b {
self.0
.iter()
.flat_map(|v| PruningIter::new(v, range, MetByPruner))
.map(|v| (v.interval().as_range(), v.value()))
}
pub fn iter_starts<'a: 'b, 'b>(
&'a self,
range: &'b Range<R>,
) -> impl Iterator<Item = (&'a Range<R>, &'a V)> + 'b {
self.0
.iter()
.flat_map(|v| PruningIter::new(v, range, StartsPruner))
.map(|v| (v.interval().as_range(), v.value()))
}
pub fn iter_finishes<'a: 'b, 'b>(
&'a self,
range: &'b Range<R>,
) -> impl Iterator<Item = (&'a Range<R>, &'a V)> + 'b {
self.0
.iter()
.flat_map(|v| PruningIter::new(v, range, FinishesPruner))
.map(|v| (v.interval().as_range(), v.value()))
}
pub fn iter_during<'a: 'b, 'b>(
&'a self,
range: &'b Range<R>,
) -> impl Iterator<Item = (&'a Range<R>, &'a V)> + 'b {
self.0
.iter()
.flat_map(|v| PruningIter::new(v, range, DuringPruner))
.map(|v| (v.interval().as_range(), v.value()))
}
pub fn iter_contains<'a: 'b, 'b>(
&'a self,
range: &'b Range<R>,
) -> impl Iterator<Item = (&'a Range<R>, &'a V)> + 'b {
self.0
.iter()
.flat_map(|v| PruningIter::new(v, range, ContainsPruner))
.map(|v| (v.interval().as_range(), v.value()))
}
pub fn max_interval_end(&self) -> Option<&R> {
self.0.as_ref().map(|root| root.subtree_max())
}
}
impl<R, V> core::iter::IntoIterator for IntervalTree<R, V> {
type Item = (Range<R>, V);
type IntoIter = OwnedIter<R, V>;
fn into_iter(self) -> Self::IntoIter {
OwnedIter::new(self.0)
}
}
#[cfg(test)]
mod tests {
    use std::prelude::v1::*;
    use std::{
        collections::{HashMap, HashSet},
        sync::{atomic::AtomicUsize, Arc},
    };

    use proptest::prelude::*;

    use super::*;
    use crate::test_utils::{arbitrary_range, Lfsr, NodeFilterCount};

    #[test]
    fn test_insert_contains() {
        let mut t = IntervalTree::default();

        t.insert(42..45, 1);
        t.insert(22..23, 2);
        t.insert(25..29, 3);

        assert!(t.contains_key(&(42..45)));
        assert!(t.contains_key(&(22..23)));
        assert!(t.contains_key(&(25..29)));

        // Keys match exactly: ranges differing from a stored key in just one
        // bound are not found.
        assert!(!t.contains_key(&(42..46)));
        assert!(!t.contains_key(&(42..44)));
        assert!(!t.contains_key(&(41..45)));
        assert!(!t.contains_key(&(43..45)));

        validate_tree_structure(&t);
    }

    #[test]
    fn test_insert_refs() {
        // Values may be references.
        let mut t = IntervalTree::default();
        t.insert(42..45, "bananas");
        assert!(t.contains_key(&(42..45)));
        validate_tree_structure(&t);
    }

    /// Upper bound on the number of ranges generated per property test case.
    const N_VALUES: usize = 200;

    /// One operation applied to both the tree and a `HashMap` model in
    /// `prop_tree_operations`.
    #[derive(Debug)]
    enum Op {
        Insert(Range<usize>, usize),
        Get(Range<usize>),
        ContainsKey(Range<usize>),
        Update(Range<usize>, usize),
        Remove(Range<usize>),
    }

    /// Strategy producing a random [`Op`].
    fn arbitrary_op() -> impl Strategy<Value = Op> {
        prop_oneof![
            (arbitrary_range(), any::<usize>()).prop_map(|(r, v)| Op::Insert(r, v)),
            (arbitrary_range(), any::<usize>()).prop_map(|(r, v)| Op::Update(r, v)),
            arbitrary_range().prop_map(Op::Get),
            arbitrary_range().prop_map(Op::ContainsKey),
            arbitrary_range().prop_map(Op::Remove),
        ]
    }

    proptest! {
        // Every inserted range is found; ranges never inserted are not.
        #[test]
        fn prop_insert_contains(
            a in prop::collection::hash_set(arbitrary_range(), 0..N_VALUES),
            b in prop::collection::hash_set(arbitrary_range(), 0..N_VALUES),
        ) {
            let mut t = IntervalTree::default();

            for v in &a {
                assert!(!t.contains_key(v));
            }

            for v in &a {
                t.insert(v.clone(), 42);
            }

            for v in &a {
                assert!(t.contains_key(v));
            }

            for v in b.difference(&a) {
                assert!(!t.contains_key(v));
            }

            validate_tree_structure(&t);
        }

        // insert/get/remove agree with a HashMap model, including the
        // "previous value" returned by insert.
        #[test]
        fn prop_range_to_value_mapping(
            values in prop::collection::hash_map(arbitrary_range(), any::<usize>(), 0..N_VALUES),
        ) {
            let mut t = IntervalTree::default();
            let mut control = HashMap::with_capacity(values.len());

            for (range, v) in &values {
                assert_eq!(t.insert(range.clone(), v), control.insert(range, v));
            }

            validate_tree_structure(&t);

            for range in values.keys() {
                assert_eq!(t.get(range), control.get(range));
            }

            for (range, v) in control {
                assert_eq!(t.remove(range).unwrap(), v);
            }

            validate_tree_structure(&t);
        }

        // Removing every key one at a time keeps the tree valid and leaves it
        // empty.
        #[test]
        fn prop_insert_contains_remove(
            values in prop::collection::hash_set(arbitrary_range(), 0..N_VALUES),
        ) {
            let mut t = IntervalTree::default();

            for v in &values {
                t.insert(v.clone(), 42);
            }

            validate_tree_structure(&t);

            for v in &values {
                assert!(t.contains_key(v));
                assert_eq!(t.remove(v), Some(42));
                assert!(!t.contains_key(v));
                assert_eq!(t.remove(v), None);
                validate_tree_structure(&t);
            }

            // Every key was removed, so nothing may remain.
            assert!(t.iter().next().is_none());
            assert_eq!(t.max_interval_end(), None);

            assert_eq!(t.remove(&(N_VALUES..N_VALUES+1)), None);
        }

        // A random sequence of operations behaves exactly like a HashMap
        // model, and the tree stays structurally valid after every step.
        #[test]
        fn prop_tree_operations(
            ops in prop::collection::vec(arbitrary_op(), 1..50),
        ) {
            let mut t = IntervalTree::default();
            let mut model = HashMap::new();

            for op in ops {
                match op {
                    Op::Insert(range, v) => {
                        assert_eq!(t.insert(range.clone(), v), model.insert(range, v));
                    },
                    Op::Update(range, value) => {
                        assert_eq!(t.get_mut(&range), model.get_mut(&range));

                        if let Some(v) = t.get_mut(&range) {
                            *v = value;
                            *model.get_mut(&range).unwrap() = value;
                        }

                        assert_eq!(t.get(&range), model.get(&range));
                    },
                    Op::Get(range) => {
                        assert_eq!(t.get(&range), model.get(&range));
                    },
                    Op::ContainsKey(range) => {
                        assert_eq!(t.contains_key(&range), model.contains_key(&range));
                    },
                    Op::Remove(range) => {
                        assert_eq!(t.remove(&range), model.remove(&range));
                    },
                }

                validate_tree_structure(&t);
            }

            // The tree must hold exactly the model's entries: no extras, and
            // every model key maps to the same value.
            assert_eq!(t.iter().count(), model.len());
            for (range, v) in &model {
                assert_eq!(t.get(range), Some(v));
            }
        }

        // `iter` is repeatable, yields strictly ascending intervals, and
        // covers every inserted entry.
        #[test]
        fn prop_iter(
            values in prop::collection::hash_map(
                arbitrary_range(), any::<usize>(),
                0..N_VALUES
            ),
        ) {
            let mut t = IntervalTree::default();

            for (range, value) in &values {
                t.insert(range.clone(), *value);
            }

            let tuples = t.iter().collect::<Vec<_>>();

            {
                let tuples2 = t.iter().collect::<Vec<(&Range<usize>, &usize)>>();
                assert_eq!(tuples, tuples2);
            }

            for window in tuples.windows(2) {
                let a = Interval::from(window[0].0.clone());
                let b = Interval::from(window[1].0.clone());
                assert!(a < b);
            }

            let tuples = tuples
                .into_iter()
                .map(|(r, v)| (r.clone(), *v))
                .collect::<HashMap<_, _>>();

            assert_eq!(tuples, values);
        }

        // `into_iter` yields strictly ascending intervals and every entry.
        #[test]
        fn prop_into_iter(
            values in prop::collection::hash_map(
                arbitrary_range(), any::<usize>(),
                0..N_VALUES
            ),
        ) {
            let mut t = IntervalTree::default();

            for (range, value) in &values {
                t.insert(range.clone(), *value);
            }

            let tuples = t.into_iter().collect::<Vec<(Range<usize>, usize)>>();

            for window in tuples.windows(2) {
                let a = Interval::from(window[0].0.clone());
                let b = Interval::from(window[1].0.clone());
                assert!(a < b);
            }

            // The ranges are already owned; no clone is needed.
            let tuples = tuples.into_iter().collect::<HashMap<_, _>>();

            assert_eq!(tuples, values);
        }
    }

    // Runs a `PruningIter` with pruner `$ty` over tree `$t` for the fixed
    // query 42..1042. It asserts how many entries are yielded and how many
    // nodes `NodeFilterCount` records (reported as "visited").
    //
    // NOTE(review): `paste!` appears to be what lets the `$ty:ty` fragment be
    // used as a value expression inside `NodeFilterCount::new(...)`; do not
    // drop it without checking that this still compiles.
    macro_rules! assert_pruning_stats {
        (
            $t:ident,
            $ty:ty,
            want_yield = $want_yield:literal,
            want_visited = $want_visited:literal
        ) => {
            paste::paste! {{
                let n_filtered = Arc::new(AtomicUsize::new(0));

                let iter = PruningIter::new(
                    $t.0.as_deref().unwrap(),
                    &Range { start: 42, end: 1042 },
                    NodeFilterCount::new($ty, Arc::clone(&n_filtered)),
                );

                let n_yielded = iter.count();
                assert_eq!(n_yielded, $want_yield, "yield count differs for {}", stringify!($ty));

                let n_filtered = n_filtered.load(std::sync::atomic::Ordering::Relaxed);
                assert_eq!(n_filtered, $want_visited, "visited count differs for {}", stringify!($ty));
            }}
        };
    }

    /// Golden-value regression test for pruning effectiveness.
    ///
    /// It inserts N = 32767 pseudo-random ranges built from two LFSR streams.
    /// For each pruner it pins how many entries are yielded and how many nodes
    /// are visited. A change in a `want_visited` value means the pruner now
    /// walks more or less of the tree.
    #[test]
    fn test_pruning_effectiveness() {
        const N: usize = (u16::MAX as usize - 1) / 2;

        let mut t: IntervalTree<_, usize> = IntervalTree::default();

        let mut a = Lfsr::new(42);
        let mut b = Lfsr::new(24);

        for i in 0..N {
            let a = a.next();
            let b = b.next();

            // Order the two samples so that start <= end.
            let r = Range {
                start: a.min(b),
                end: a.max(b),
            };

            t.insert(r, i);
        }

        assert_pruning_stats!(t, OverlapsPruner, want_yield = 1043, want_visited = 1044);
        assert_pruning_stats!(t, PrecedesPruner, want_yield = 1, want_visited = 49);
        assert_pruning_stats!(
            t,
            PrecededByPruner,
            want_yield = 31722,
            want_visited = 32759
        );
        assert_pruning_stats!(t, MeetsPruner, want_yield = 0, want_visited = 49);
        assert_pruning_stats!(t, MetByPruner, want_yield = 1, want_visited = 32759);
        assert_pruning_stats!(t, StartsPruner, want_yield = 0, want_visited = 49);
        assert_pruning_stats!(t, FinishesPruner, want_yield = 0, want_visited = 32759);
        assert_pruning_stats!(t, DuringPruner, want_yield = 24, want_visited = 1045);
        assert_pruning_stats!(t, ContainsPruner, want_yield = 48, want_visited = 49);
    }

    // Generates `prop_algebraic_iter_<name>`. The test checks that
    // `IntervalTree::iter_<name>(query)` yields exactly the (well-formed)
    // stored ranges `v` for which `Interval::from(v).<name>(&query)` holds,
    // compared against a brute-force filter.
    macro_rules! test_algebraic_iter {
        ($name:tt) => {
            paste::paste! {
                proptest! {
                    #[test]
                    fn [<prop_algebraic_iter_ $name>](
                        query in arbitrary_range().prop_filter("invalid query interval", is_sane_interval),
                        values in prop::collection::vec(
                            arbitrary_range(),
                            0..10
                        ),
                    ) {
                        // Brute-force expected result.
                        let control = values
                            .iter()
                            .filter(|&v| is_sane_interval(v))
                            .filter(|&v| Interval::from(v.clone()).$name(&query))
                            .collect::<HashSet<_>>();

                        let mut t = IntervalTree::default();
                        for range in &values {
                            t.insert(range.clone(), 42);
                        }

                        let got = t
                            .[<iter_ $name>](&query)
                            .map(|v| v.0)
                            .filter(|&v| is_sane_interval(v))
                            .collect::<HashSet<_>>();

                        assert_eq!(got, control);
                    }
                }
            }
        };
    }

    /// Returns `true` when `r` is well-formed, i.e. `start <= end`.
    fn is_sane_interval<R>(r: &Range<R>) -> bool
    where
        R: Ord,
    {
        r.start <= r.end
    }

    test_algebraic_iter!(overlaps);
    test_algebraic_iter!(precedes);
    test_algebraic_iter!(preceded_by);
    test_algebraic_iter!(meets);
    test_algebraic_iter!(met_by);
    test_algebraic_iter!(starts);
    test_algebraic_iter!(finishes);
    test_algebraic_iter!(during);
    test_algebraic_iter!(contains);

    /// Walks every node of `t` and asserts these structural invariants:
    ///
    /// 1. each immediate child is ordered correctly relative to its parent;
    /// 2. each node's height is one more than its tallest child (leaves are 0);
    /// 3. each node's AVL balance factor is within ±1, counting a missing child
    ///    as height -1 so that single-child nodes are checked too;
    /// 4. each node's `subtree_max` is the largest interval end in its subtree,
    ///    and the largest of these equals `t.max_interval_end()`.
    ///
    /// An empty tree trivially passes.
    fn validate_tree_structure<R, V>(t: &IntervalTree<R, V>)
    where
        R: Ord + PartialEq + Debug + Clone,
        V: Debug,
    {
        let root = match t.0.as_deref() {
            Some(v) => v,
            None => return,
        };

        let tree_max = t.max_interval_end();
        let mut nodes_max = None;

        let mut stack = vec![root];
        while let Some(n) = stack.pop() {
            stack.extend(n.left().iter().chain(n.right().iter()));

            // Invariant 1: immediate children are ordered around their parent.
            assert!(n
                .left()
                .map(|v| v.interval() < n.interval())
                .unwrap_or(true));
            assert!(n
                .right()
                .map(|v| v.interval() > n.interval())
                .unwrap_or(true));

            // Invariant 2: height is one more than the tallest child.
            let left_height = n.left().map(|v| v.height());
            let right_height = n.right().map(|v| v.height());
            let want_height = left_height
                .max(right_height)
                .map(|v| v + 1)
                .unwrap_or_default();
            assert_eq!(
                n.height(),
                want_height,
                "expect node with interval {:?} to have height {}, has {}",
                n.interval(),
                want_height,
                n.height(),
            );

            // Invariant 3: AVL balance. A leaf has height 0, so an absent
            // child is one lower (-1). Treating it as such catches a node
            // whose only child is itself a parent (balance 2), which would be
            // missed if the check were skipped whenever a child is absent.
            let left = left_height.map_or(-1, |h| h as i64);
            let right = right_height.map_or(-1, |h| h as i64);
            let balance = (left - right).abs();
            assert!(
                balance <= 1,
                "balance={balance}, node={n:?}, stack={stack:?}"
            );

            // Invariant 4: subtree_max covers this node and both subtrees.
            let child_max = n
                .left()
                .map(|v| v.subtree_max())
                .max(n.right().map(|v| v.subtree_max()));
            let want_max = child_max.max(Some(n.interval().end())).unwrap();
            assert_eq!(want_max, n.subtree_max());

            // `None` orders below any `Some`, so this also seeds the first value.
            nodes_max = nodes_max.max(Some(want_max));
        }

        assert_eq!(tree_max, nodes_max);
    }

    // Compile-only checks: these functions are never called (hence the
    // `dead_code` allowance). They prove that the references yielded by each
    // `iter_*` method are tied to the tree's lifetime `'a`, not to the query
    // `Range`, which is a local dropped before the results are returned.
    #[allow(dead_code)]
    mod pruning_iter_lifetime {
        use super::*;

        macro_rules! compile_test {
            ($fn:ident) => {
                fn $fn<'a, R: Ord + Clone + Debug, V>(
                    t: &'a IntervalTree<R, V>,
                    range: Range<R>,
                ) -> Vec<&'a V> {
                    t.$fn(&range).map(|(_, x)| x).collect()
                }
            };
        }

        compile_test!(iter_overlaps);
        compile_test!(iter_precedes);
        compile_test!(iter_preceded_by);
        compile_test!(iter_meets);
        compile_test!(iter_met_by);
        compile_test!(iter_starts);
        compile_test!(iter_finishes);
        compile_test!(iter_during);
        compile_test!(iter_contains);
    }
}