lk_math/
interval_set.rs

1use std::{
2    iter::Sum,
3    ops::{Add, Sub},
4};
5
6use crate::interval::UniversalInterval;
7
8use super::interval::{ExclusiveMax, InclusiveMin, Interval};
9
/// Disjoint set of intervals.
///
/// `T` must implement `Copy` and `Ord`.
///
/// Because of the `Ord` constraint, floating point types are not supported.
/// This can be worked around by creating a wrapper type that implements `Ord`.
/// Wrappers [`OrdF32`](crate::ord_float::OrdF32) and [`OrdF64`](crate::ord_float::OrdF64)
/// are provided in the `ord_float` module.
///
/// [`IntervalSet::union`] keeps `intervals` sorted by start and ignores empty ranges.
/// The field is public, so code that writes to it directly is responsible for keeping
/// it sorted and disjoint; the lookups (`contains`, `containing_interval`, `bounds`, ...)
/// rely on that ordering.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct IntervalSet<T> {
    /// Member intervals, sorted by start.
    pub intervals: Vec<std::ops::Range<T>>,
}
21
22impl<T: Copy + Ord> IntervalSet<T> {
23    pub fn new() -> Self {
24        Self { intervals: vec![] }
25    }
26
27    pub fn intersect(&mut self, interval: std::ops::Range<T>) {
28        self.intervals = self
29            .intervals
30            .iter()
31            .filter_map(|x| x.intersection(&interval))
32            .collect();
33    }
34
35    /// Remove all intervals that do not intersect with the given interval.
36    pub fn retain_intersecting(&mut self, interval: std::ops::Range<T>) {
37        self.intervals = self
38            .intervals
39            .iter()
40            .filter(|x| x.intersection(&interval).is_some())
41            .cloned()
42            .collect();
43    }
44
45    pub fn union(&mut self, interval: std::ops::Range<T>) {
46        if *interval.inclusive_min() >= *interval.exclusive_max() {
47            return;
48        }
49
50        if self.intervals.is_empty() {
51            self.intervals.push(interval);
52            return;
53        }
54
55        let index0 = match self
56            .intervals
57            .binary_search_by(|x| x.inclusive_min().cmp(interval.inclusive_min()))
58        {
59            Ok(value) => value,
60            Err(value) => value,
61        };
62        let index1 = match self
63            .intervals
64            .binary_search_by(|x| x.exclusive_max().cmp(interval.exclusive_max()))
65        {
66            Ok(value) => value,
67            Err(value) => value,
68        };
69
70        if index0 > index1 {
71            // NOTE(lubo): Already included
72            return;
73        }
74
75        if index0 < index1 {
76            // NOTE(lubo): We can definitely remove n = (index1 - index0) segments.
77            // Segments to definitely remove:
78            //  1. index0
79            //  2. index0 + 1
80            //  ...
81            //  n. index0 + n - 1
82            self.intervals.drain(index0..index1);
83        }
84
85        // NOTE(lubo): Either
86        // 1. add new segment (+1 total)
87        // 2. join left segment
88        // 3. join right segment
89        // 4. join both (-1 total)
90        let index = index0;
91
92        if index > 0 {
93            let pre = self.intervals[index - 1].union(&interval);
94            if let Some(mut interval) = pre {
95                if index < self.intervals.len() {
96                    let all_three = self.intervals[index].union(&interval);
97                    if let Some(all_three) = all_three {
98                        interval = all_three;
99                        self.intervals.remove(index);
100                    }
101                }
102
103                self.intervals[index - 1] = interval;
104                return;
105            }
106        }
107
108        if index < self.intervals.len() {
109            let post = self.intervals[index].union(&interval);
110            if let Some(post) = post {
111                self.intervals[index] = post;
112                return;
113            }
114        }
115
116        self.intervals.insert(index, interval);
117    }
118}
119
120impl<T: Copy + Ord> Default for IntervalSet<T> {
121    fn default() -> Self {
122        Self::new()
123    }
124}
125
126impl<T: Copy + Add<Output = T> + Sub<Output = T> + Sum> IntervalSet<T> {
127    pub fn measure(&self) -> T {
128        self.intervals
129            .iter()
130            .map(|x| *x.exclusive_max() - *x.inclusive_min())
131            .sum()
132    }
133
134    pub fn bounds(&self) -> Option<std::ops::Range<T>> {
135        let count = self.intervals.len();
136        if count > 0 {
137            Some(*self.intervals[0].inclusive_min()..*self.intervals[count - 1].exclusive_max())
138        } else {
139            None
140        }
141    }
142
143    /// Negation of the set of intervals.
144    ///
145    /// The negation of an empty set is the entire domain (the "universal interval").
146    /// This requires the notion of "most extreme values" for the type `T`.
147    /// For example, the most extreme values for `i32` are `i32::MIN` and `i32::MAX`.
148    /// For `f32`, the most extreme values would be `f32::NEG_INFINITY` and `f32::INFINITY`.
149    /// (Although `f32` cannot be used since it does not implement `Ord`. See [`crate::ord_float::OrdF32`].)
150    /// These bounds are defined in the [`UniversalInterval`] trait which is required for
151    /// this function.
152    ///
153    /// See [`negation_within_bounds`] for a version that does not require universal bounds.
154    pub fn negation(&self) -> Self
155    where
156        T: UniversalInterval,
157    {
158        let count = self.intervals.len();
159
160        if count > 0 {
161            let mut negated = vec![];
162
163            if !self.intervals[0].inclusive_min().is_infinum() {
164                negated.push(T::INFINUM..*self.intervals[0].inclusive_min());
165            }
166
167            for i in 0..count - 1 {
168                negated.push(
169                    *self.intervals[i].exclusive_max()..*self.intervals[i + 1].inclusive_min(),
170                )
171            }
172
173            if !self.intervals[count - 1].exclusive_max().is_supremum() {
174                negated.push(*self.intervals[count - 1].exclusive_max()..T::SUPREMUM);
175            }
176
177            Self { intervals: negated }
178        } else {
179            Self {
180                intervals: vec![T::universal_interval()],
181            }
182        }
183    }
184
185    pub fn negation_within_bounds(&self) -> Self {
186        let count = self.intervals.len();
187
188        if count > 0 {
189            let mut negated = vec![];
190
191            for i in 0..count - 1 {
192                negated.push(
193                    *self.intervals[i].exclusive_max()..*self.intervals[i + 1].inclusive_min(),
194                )
195            }
196
197            Self { intervals: negated }
198        } else {
199            Self { intervals: vec![] }
200        }
201    }
202}
203
204impl<T: Copy + Ord> IntervalSet<T> {
205    pub fn containing_interval(&self, value: &T) -> Option<std::ops::Range<T>> {
206        let index0 = match self
207            .intervals
208            .binary_search_by(|probe| probe.exclusive_max().cmp(value))
209        {
210            Ok(value) => value,
211            Err(value) => value,
212        };
213        if let Some(a) = self.intervals.get(index0) {
214            if a.contains(value) {
215                Some(a.clone())
216            } else {
217                None
218            }
219        } else {
220            None
221        }
222    }
223
224    pub fn contains(&self, value: &T) -> bool {
225        self.containing_interval(value).is_some()
226    }
227}
228
229#[cfg(test)]
230mod tests {
231    use crate::{
232        interval_set::IntervalSet,
233        ord_float::{OrdF32, OrdF64},
234    };
235
236    #[test]
237    fn empty() {
238        let set = IntervalSet::<i32>::new();
239        assert_eq!(set.measure(), 0);
240        assert_eq!(set.bounds(), None);
241        assert!(set.negation_within_bounds().intervals.is_empty());
242
243        assert_eq!(set.negation().intervals, vec![-2147483648..2147483647]);
244        assert_eq!(set.negation().negation(), set);
245
246        assert!(!set.contains(&i32::MIN));
247        assert!(!set.contains(&-1));
248        assert!(!set.contains(&0));
249        assert!(!set.contains(&1));
250        assert!(!set.contains(&i32::MAX));
251    }
252
253    #[test]
254    fn i32() {
255        let a = 0..2;
256        let b = 1..3;
257        let mut set = IntervalSet::new();
258        set.union(a);
259        set.union(b);
260        assert_eq!(set.measure(), 3);
261
262        assert_eq!(
263            set.negation().intervals,
264            vec![-2147483648..0, 3..2147483647]
265        );
266        assert_eq!(set.negation().negation(), set);
267
268        assert!(!set.contains(&i32::MIN));
269        assert!(!set.contains(&-1));
270        assert!(set.contains(&0));
271        assert!(set.contains(&1));
272        assert!(set.contains(&2));
273        assert!(!set.contains(&3));
274        assert!(!set.contains(&i32::MAX));
275    }
276
277    #[test]
278    fn f32() {
279        let a = OrdF32(0.0)..OrdF32(2.0);
280        let b = OrdF32(1.0)..OrdF32(3.0);
281        let mut set = IntervalSet::new();
282        set.union(a);
283        set.union(b);
284        assert_eq!(*set.measure(), 3.0);
285
286        assert_eq!(
287            set.negation().intervals,
288            vec![
289                OrdF32(f32::NEG_INFINITY)..OrdF32(0.0),
290                OrdF32(3.0)..OrdF32(f32::INFINITY)
291            ]
292        );
293        assert_eq!(set.negation().negation(), set);
294
295        assert!(!set.contains(&OrdF32(f32::NEG_INFINITY)));
296        assert!(!set.contains(&OrdF32(f32::MIN)));
297        assert!(!set.contains(&OrdF32(-1.0)));
298        assert!(!set.contains(&OrdF32(-f32::EPSILON)));
299        assert!(set.contains(&OrdF32(0.0)));
300        assert!(set.contains(&OrdF32(1.0)));
301        assert!(set.contains(&OrdF32(2.0)));
302        assert!(set.contains(&OrdF32(2.999)));
303        assert!(!set.contains(&OrdF32(3.0)));
304        assert!(!set.contains(&OrdF32(f32::MAX)));
305        assert!(!set.contains(&OrdF32(f32::INFINITY)));
306    }
307
308    #[test]
309    fn f64() {
310        let a = OrdF64(0.0)..OrdF64(2.0);
311        let b = OrdF64(1.0)..OrdF64(3.0);
312        let mut set = IntervalSet::new();
313        set.union(a);
314        set.union(b);
315        assert_eq!(*set.measure(), 3.0);
316
317        assert_eq!(
318            set.negation().intervals,
319            vec![
320                OrdF64(f64::NEG_INFINITY)..OrdF64(0.0),
321                OrdF64(3.0)..OrdF64(f64::INFINITY)
322            ]
323        );
324        assert_eq!(set.negation().negation(), set);
325
326        assert!(!set.contains(&OrdF64(f64::NEG_INFINITY)));
327        assert!(!set.contains(&OrdF64(f64::MIN)));
328        assert!(!set.contains(&OrdF64(-1.0)));
329        assert!(!set.contains(&OrdF64(-f64::EPSILON)));
330        assert!(set.contains(&OrdF64(0.0)));
331        assert!(set.contains(&OrdF64(1.0)));
332        assert!(set.contains(&OrdF64(2.0)));
333        assert!(set.contains(&OrdF64(2.999)));
334        assert!(!set.contains(&OrdF64(3.0)));
335        assert!(!set.contains(&OrdF64(f64::MAX)));
336        assert!(!set.contains(&OrdF64(f64::INFINITY)));
337    }
338}