bl_range_set/
range_set.rs

1
2use super::number::Number;
3
/// Shorthand for `std::result::Result` whose error type is this module's [`Error`].
type Result<T> = std::result::Result<T, Error>;

/// Type-erased error returned by the fallible operations in this module.
///
/// Any `std::error::Error` implementor, as well as a plain `&str` message,
/// converts into it with `.into()`.
type Error = Box<dyn std::error::Error + 'static>;
6
7/// A half-open range [start, end)
8#[derive(Debug, Clone, Copy, PartialEq, Eq)]
9pub struct Range<T: Number>(T, T);
10
11impl<T: Number> Range<T> {
12    pub fn new(start: T, end: T) -> Result<Self> {
13        if start.is_nan() || end.is_nan() {
14            return Err("NaN is not allowed in range".into());
15        }
16
17        if start > end {
18            return Err("Invalid range (negative size)".into());
19        }
20
21        Ok(Range(start, end))
22    }
23
24    pub fn len(&self) -> T {
25        self.1 - self.0
26    }
27
28    pub fn contains(&self, value: T) -> bool {
29        value >= self.0 && value < self.1
30    }
31
32    pub fn try_merge(&self, other: &Self) -> Result<Self> {
33        if self.1 < other.0 || other.1 < self.0 {
34            return Err("Disjoint ranges cannot be merged".into());
35        }
36
37        Ok(Range(self.0.min(other.0), self.1.max(other.1)))
38    }
39}
40
41impl<T: Number> From<(T, T)> for Range<T> {
42    fn from(tuple: (T, T)) -> Self {
43        Range::new(tuple.0, tuple.1).unwrap()
44    }
45}
46
47/// A set of non-overlapping, sorted ranges. Note that T::MAX cannot be included,
48/// due to the half-open nature of the ranges.
49#[derive(Debug, Clone, PartialEq)]
50pub struct RangeSet<T: Number> {
51    ranges: Vec<Range<T>>,
52}
53
54impl<T: Number> RangeSet<T> {
55    pub fn new() -> Self {
56        RangeSet { ranges: Vec::new() }
57    }
58
59    fn binary_search_by_first(&self, value: T) -> std::result::Result<usize, usize> {
60        self.ranges
61            .binary_search_by(|r| r.0.partial_cmp(&value).unwrap())
62    }
63
64    // Add range to the set in sorted order, merging overlapping ranges.
65    pub fn add_range<R: Into<Range<T>>>(&mut self, range: R) {
66        let range = range.into();
67        if range.len() == T::zero() {
68            return;
69        }
70
71        if self.ranges.is_empty() {
72            self.ranges.push(range);
73            return;
74        }
75
76        // Find the position to insert the new range.
77        let start_pos = match self.binary_search_by_first(range.0) {
78            Ok(pos) => pos,
79            Err(0) => 0,
80            Err(pos) => {
81                // Check for overlap with the previous range.
82                if range.0 <= self.ranges[pos - 1].1 {
83                    pos - 1
84                } else {
85                    pos
86                }
87            }
88        };
89
90        let end_pos = match self.binary_search_by_first(range.1){
91            Ok(pos) => pos + 1,
92            Err(pos) => pos,
93        };
94
95        if start_pos == end_pos {
96            self.ranges.insert(start_pos, range);
97        } else {
98            let new_start = self.ranges[start_pos].0.min(range.0);
99            let new_end = self.ranges[end_pos - 1].1.max(range.1);
100            self.ranges[start_pos].0 = new_start;
101            self.ranges[start_pos].1 = new_end;
102            self.ranges.drain(start_pos + 1..end_pos);
103        }
104    }
105
106    pub fn contains(&self, value: T) -> bool {
107        match self.binary_search_by_first(value) {
108            Ok(_) => true,
109            Err(0) => false,
110            Err(pos) => self.ranges[pos - 1].contains(value),
111        }
112    }
113
114}
115
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_float_range_new() {
        assert!(Range::new(1.0f64, 2.0).is_ok());
        assert!(Range::new(1.0f64, 1.0).is_ok());
        assert!(Range::new(2.0f64, 1.0).is_err());
        assert!(Range::new(f64::NAN, 2.0).is_err());
        assert!(Range::new(1.0f64, f64::NAN).is_err());
    }

    #[test]
    fn test_float_range_contains() {
        let range = Range::new(1.0f64, 3.0).unwrap();
        assert!(range.contains(1.0));
        assert!(range.contains(2.0));
        assert!(!range.contains(3.0));
    }

    #[test]
    fn test_float_range_try_merge() {
        let range1 = Range::new(1.0f64, 3.0).unwrap();
        let range2 = Range::new(2.0f64, 4.0).unwrap();
        let merged = range1.try_merge(&range2).unwrap();
        assert_eq!(merged, Range::new(1.0, 4.0).unwrap());

        let range3 = Range::new(4.0f64, 5.0).unwrap();
        assert!(range1.try_merge(&range3).is_err());
        assert!(range2.try_merge(&range3).is_ok());
    }

    #[test]
    fn test_float_range_from_tuple() {
        let range: Range<f64> = (1.0, 2.0).into();
        assert_eq!(range, Range::new(1.0, 2.0).unwrap());
    }

    #[test]
    fn test_float_range_set_add_and_contains() {
        let mut range_set = RangeSet::new();
        range_set.add_range(Range::new(1.0f64, 3.0).unwrap());
        range_set.add_range(Range::new(4.0f64, 7.0).unwrap());
        range_set.add_range(Range::new(8.0f64, 10.0).unwrap());

        let in_set = vec![1.0, 1.5, 2.0, 2.9, 4.0, 6.0, 8.0, 9.0];
        for &value in &in_set {
            assert!(range_set.contains(value));
        }

        let not_in_set = vec![-1.0, 0.0, 3.0, 3.5, 7.0, 7.5, 10.0, 11.0];
        for &value in &not_in_set {
            assert!(!range_set.contains(value));
        }
    }

    #[test]
    fn test_float_range_set_inserts_in_sorted_order() {
        let mut range_set: RangeSet<f64> = RangeSet::new();
        range_set.add_range((8.0, 10.0));
        range_set.add_range((1.0, 3.0));
        range_set.add_range((4.0, 7.0));
        assert_eq!(
            range_set.ranges,
            vec![
                Range::new(1.0, 3.0).unwrap(),
                Range::new(4.0, 7.0).unwrap(),
                Range::new(8.0, 10.0).unwrap(),
            ]
        );
    }

    #[test]
    fn test_float_range_set_merges_touching_neighbours() {
        let mut range_set: RangeSet<f64> = RangeSet::new();
        range_set.add_range((1.0, 3.0));
        range_set.add_range((5.0, 7.0));
        // Exactly fills the gap, touching both neighbours.
        range_set.add_range((3.0, 5.0));
        assert_eq!(range_set.ranges, vec![Range::new(1.0, 7.0).unwrap()]);
    }

    #[test]
    fn test_float_range_set_absorbs_multiple_ranges() {
        let mut range_set: RangeSet<f64> = RangeSet::new();
        range_set.add_range((1.0, 2.0));
        range_set.add_range((3.0, 4.0));
        range_set.add_range((5.0, 6.0));
        range_set.add_range((0.0, 10.0));
        assert_eq!(range_set.ranges, vec![Range::new(0.0, 10.0).unwrap()]);
    }

    #[test]
    fn test_float_range_set_contained_range_is_noop() {
        let mut range_set: RangeSet<f64> = RangeSet::new();
        range_set.add_range((1.0, 10.0));
        range_set.add_range((2.0, 3.0));
        assert_eq!(range_set.ranges, vec![Range::new(1.0, 10.0).unwrap()]);
    }

    #[test]
    fn test_float_range_set_ignores_empty_range() {
        let mut range_set: RangeSet<f64> = RangeSet::new();
        range_set.add_range((5.0, 5.0));
        assert!(range_set.ranges.is_empty());
        assert!(!range_set.contains(5.0));
    }
}