bl_range_set/
discrete_range_set.rs

1use num_traits::PrimInt;
2
/// Boxed, type-erased error; the error half of this file's `Result` alias.
type Error = Box<dyn std::error::Error>;
/// Shorthand for `std::result::Result` with the file-local `Error` as its error type.
type Result<T> = std::result::Result<T, Error>;
5
6#[derive(Debug, Clone, Copy, PartialEq, Eq)]
7pub struct DiscreteRange<T: PrimInt>(T, T);
8
9impl<T: PrimInt> DiscreteRange<T> {
10    pub fn new(start: T, end: T) -> Result<Self> {
11        if start > end {
12            return Err("Invalid range (negative size)".into());
13        }
14
15        Ok(DiscreteRange(start, end))
16    }
17
18    pub fn len(&self) -> T {
19        self.1 - self.0 + T::one()
20    }
21
22    pub fn contains(&self, value: T) -> bool {
23        value >= self.0 && value <= self.1
24    }
25
26    pub fn try_merge(&self, other: &Self) -> Result<Self> {
27        if self.1 < other.0 - T::one() || other.1 < self.0 - T::one() {
28            return Err("Disjoint ranges cannot be merged".into());
29        }
30
31        Ok(DiscreteRange(self.0.min(other.0), self.1.max(other.1)))
32    }
33}
34
35impl<T: PrimInt> From<(T, T)> for DiscreteRange<T> {
36    fn from(tuple: (T, T)) -> Self {
37        DiscreteRange::new(tuple.0, tuple.1).unwrap()
38    }
39}
40
41#[derive(Debug, Clone, PartialEq, Eq)]
42pub struct DiscreteRangeSet<T: PrimInt> {
43    ranges: Vec<DiscreteRange<T>>,
44}
45
46impl<T: PrimInt> DiscreteRangeSet<T> {
47    pub fn new() -> Self {
48        DiscreteRangeSet { ranges: Vec::new() }
49    }
50
51    fn binary_search_by_first(&self, value: T) -> std::result::Result<usize, usize> {
52        self.ranges
53            .binary_search_by(|r| r.0.partial_cmp(&value).unwrap())
54    }
55
56    fn binary_search_contained_range(&self, value: T) -> std::result::Result<usize, usize> {
57        self.ranges.binary_search_by(|r| {
58            if r.contains(value) {
59                std::cmp::Ordering::Equal
60            } else if value < r.0 {
61                std::cmp::Ordering::Greater
62            } else {
63                std::cmp::Ordering::Less
64            }
65        })
66    }
67
68    // Add range to the set in sorted order, merging overlapping ranges.
69    pub fn add_range<R: Into<DiscreteRange<T>>>(&mut self, range: R) {
70        let range = range.into();
71        if range.len() == T::zero() {
72            return;
73        }
74
75        if self.ranges.is_empty() {
76            self.ranges.push(range);
77            return;
78        }
79
80        let start_pos = match self.binary_search_contained_range(range.0) {
81            Ok(pos) => pos,
82            Err(0) => 0,
83            Err(pos) => {
84                if range.0 == self.ranges[pos - 1].1 + T::one() {
85                    pos - 1
86                } else {
87                    pos
88                }
89            }
90        };
91
92        let end_pos = match self.binary_search_contained_range(range.1) {
93            Ok(pos) => pos + 1,
94            Err(0) => 0,
95            Err(pos) => {
96                if pos != self.ranges.len() && range.1 + T::one() == self.ranges[pos].0 {
97                    pos + 1
98                } else {
99                    pos
100                }
101            }
102        };
103
104        if start_pos == end_pos {
105            self.ranges.insert(start_pos, range);
106        } else {
107            let new_start = self.ranges[start_pos].0.min(range.0);
108            let new_end = self.ranges[end_pos - 1].1.max(range.1);
109            self.ranges[start_pos].0 = new_start;
110            self.ranges[start_pos].1 = new_end;
111            self.ranges.drain(start_pos + 1..end_pos);
112        }
113    }
114
115    pub fn contains(&self, value: T) -> bool {
116        match self.binary_search_by_first(value) {
117            Ok(_) => true,
118            Err(0) => false,
119            Err(pos) => self.ranges[pos - 1].contains(value),
120        }
121    }
122}
123
#[cfg(test)]
mod tests {
    use super::*;

    /// A range is accepted only when its start does not exceed its end.
    #[test]
    fn test_inclusive_range_new() {
        let ascending = DiscreteRange::new(1u32, 2);
        let descending = DiscreteRange::new(2u32, 1);

        assert!(ascending.is_ok());
        assert!(descending.is_err());
    }

    /// Two separate ranges stay separate until the gap between them is filled.
    #[test]
    fn test_inclusive_range_set_add_and_contains() {
        let mut set = DiscreteRangeSet::new();
        set.add_range(DiscreteRange::new(1u32, 3).unwrap());
        set.add_range(DiscreteRange::new(5u32, 7).unwrap());
        assert_eq!(set.ranges.len(), 2);

        let present = [1, 2, 3, 5, 6, 7];
        assert!(present.iter().all(|&value| set.contains(value)));

        let absent = [0, 4, 8];
        assert!(absent.iter().all(|&value| !set.contains(value)));

        // Filling the one-value gap joins both neighbours into a single range.
        set.add_range(DiscreteRange::new(4, 4).unwrap());
        assert_eq!(set.ranges.len(), 1);

        for value in 1..=7 {
            assert!(set.contains(value));
        }
        for value in &[0, 8] {
            assert!(!set.contains(*value));
        }
    }

    /// A single-value range at the top of the `u32` domain can be stored and queried.
    #[test]
    fn test_inclusive_range_max_value() {
        let top = u32::MAX;
        let mut set = DiscreteRangeSet::new();
        assert!(!set.contains(top));

        set.add_range(DiscreteRange::new(top, top).unwrap());
        assert!(set.contains(top));

        for missing in &[top - 1, u32::MIN] {
            assert!(!set.contains(*missing));
        }
    }
}