1use std::{
2 iter::Sum,
3 ops::{Add, Sub},
4};
5
6use crate::interval::UniversalInterval;
7
8use super::interval::{ExclusiveMax, InclusiveMin, Interval};
9
/// A collection of half-open ranges (`start..end`) over `T`.
///
/// The lookup and `union` methods of this type binary-search `intervals`, so
/// the vector is expected to stay sorted in ascending order with no
/// overlapping ranges.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct IntervalSet<T> {
    /// The stored ranges. The field is public, so callers that edit it
    /// directly are responsible for keeping it sorted and non-overlapping.
    pub intervals: Vec<std::ops::Range<T>>,
}
21
22impl<T: Copy + Ord> IntervalSet<T> {
23 pub fn new() -> Self {
24 Self { intervals: vec![] }
25 }
26
27 pub fn intersect(&mut self, interval: std::ops::Range<T>) {
28 self.intervals = self
29 .intervals
30 .iter()
31 .filter_map(|x| x.intersection(&interval))
32 .collect();
33 }
34
35 pub fn retain_intersecting(&mut self, interval: std::ops::Range<T>) {
37 self.intervals = self
38 .intervals
39 .iter()
40 .filter(|x| x.intersection(&interval).is_some())
41 .cloned()
42 .collect();
43 }
44
45 pub fn union(&mut self, interval: std::ops::Range<T>) {
46 if *interval.inclusive_min() >= *interval.exclusive_max() {
47 return;
48 }
49
50 if self.intervals.is_empty() {
51 self.intervals.push(interval);
52 return;
53 }
54
55 let index0 = match self
56 .intervals
57 .binary_search_by(|x| x.inclusive_min().cmp(interval.inclusive_min()))
58 {
59 Ok(value) => value,
60 Err(value) => value,
61 };
62 let index1 = match self
63 .intervals
64 .binary_search_by(|x| x.exclusive_max().cmp(interval.exclusive_max()))
65 {
66 Ok(value) => value,
67 Err(value) => value,
68 };
69
70 if index0 > index1 {
71 return;
73 }
74
75 if index0 < index1 {
76 self.intervals.drain(index0..index1);
83 }
84
85 let index = index0;
91
92 if index > 0 {
93 let pre = self.intervals[index - 1].union(&interval);
94 if let Some(mut interval) = pre {
95 if index < self.intervals.len() {
96 let all_three = self.intervals[index].union(&interval);
97 if let Some(all_three) = all_three {
98 interval = all_three;
99 self.intervals.remove(index);
100 }
101 }
102
103 self.intervals[index - 1] = interval;
104 return;
105 }
106 }
107
108 if index < self.intervals.len() {
109 let post = self.intervals[index].union(&interval);
110 if let Some(post) = post {
111 self.intervals[index] = post;
112 return;
113 }
114 }
115
116 self.intervals.insert(index, interval);
117 }
118}
119
120impl<T: Copy + Ord> Default for IntervalSet<T> {
121 fn default() -> Self {
122 Self::new()
123 }
124}
125
126impl<T: Copy + Add<Output = T> + Sub<Output = T> + Sum> IntervalSet<T> {
127 pub fn measure(&self) -> T {
128 self.intervals
129 .iter()
130 .map(|x| *x.exclusive_max() - *x.inclusive_min())
131 .sum()
132 }
133
134 pub fn bounds(&self) -> Option<std::ops::Range<T>> {
135 let count = self.intervals.len();
136 if count > 0 {
137 Some(*self.intervals[0].inclusive_min()..*self.intervals[count - 1].exclusive_max())
138 } else {
139 None
140 }
141 }
142
143 pub fn negation(&self) -> Self
155 where
156 T: UniversalInterval,
157 {
158 let count = self.intervals.len();
159
160 if count > 0 {
161 let mut negated = vec![];
162
163 if !self.intervals[0].inclusive_min().is_infinum() {
164 negated.push(T::INFINUM..*self.intervals[0].inclusive_min());
165 }
166
167 for i in 0..count - 1 {
168 negated.push(
169 *self.intervals[i].exclusive_max()..*self.intervals[i + 1].inclusive_min(),
170 )
171 }
172
173 if !self.intervals[count - 1].exclusive_max().is_supremum() {
174 negated.push(*self.intervals[count - 1].exclusive_max()..T::SUPREMUM);
175 }
176
177 Self { intervals: negated }
178 } else {
179 Self {
180 intervals: vec![T::universal_interval()],
181 }
182 }
183 }
184
185 pub fn negation_within_bounds(&self) -> Self {
186 let count = self.intervals.len();
187
188 if count > 0 {
189 let mut negated = vec![];
190
191 for i in 0..count - 1 {
192 negated.push(
193 *self.intervals[i].exclusive_max()..*self.intervals[i + 1].inclusive_min(),
194 )
195 }
196
197 Self { intervals: negated }
198 } else {
199 Self { intervals: vec![] }
200 }
201 }
202}
203
204impl<T: Copy + Ord> IntervalSet<T> {
205 pub fn containing_interval(&self, value: &T) -> Option<std::ops::Range<T>> {
206 let index0 = match self
207 .intervals
208 .binary_search_by(|probe| probe.exclusive_max().cmp(value))
209 {
210 Ok(value) => value,
211 Err(value) => value,
212 };
213 if let Some(a) = self.intervals.get(index0) {
214 if a.contains(value) {
215 Some(a.clone())
216 } else {
217 None
218 }
219 } else {
220 None
221 }
222 }
223
224 pub fn contains(&self, value: &T) -> bool {
225 self.containing_interval(value).is_some()
226 }
227}
228
#[cfg(test)]
mod tests {
    use crate::{
        interval_set::IntervalSet,
        ord_float::{OrdF32, OrdF64},
    };

    /// A freshly created set has no measure, no bounds, and contains nothing;
    /// its negation is the whole `i32` range.
    #[test]
    fn empty() {
        let set = IntervalSet::<i32>::new();
        assert_eq!(set.measure(), 0);
        assert_eq!(set.bounds(), None);
        assert!(set.negation_within_bounds().intervals.is_empty());

        assert_eq!(set.negation().intervals, vec![i32::MIN..i32::MAX]);
        assert_eq!(set.negation().negation(), set);

        for outsider in [i32::MIN, -1, 0, 1, i32::MAX].iter() {
            assert!(!set.contains(outsider));
        }
    }

    /// Two overlapping integer ranges collapse into `0..3`.
    #[test]
    fn i32() {
        let mut set = IntervalSet::new();
        set.union(0..2);
        set.union(1..3);
        assert_eq!(set.measure(), 3);

        assert_eq!(set.negation().intervals, vec![i32::MIN..0, 3..i32::MAX]);
        assert_eq!(set.negation().negation(), set);

        for member in [0, 1, 2].iter() {
            assert!(set.contains(member));
        }
        for outsider in [i32::MIN, -1, 3, i32::MAX].iter() {
            assert!(!set.contains(outsider));
        }
    }

    /// Same scenario as `i32`, using the `OrdF32` wrapper.
    #[test]
    fn f32() {
        let mut set = IntervalSet::new();
        set.union(OrdF32(0.0)..OrdF32(2.0));
        set.union(OrdF32(1.0)..OrdF32(3.0));
        assert_eq!(*set.measure(), 3.0);

        let expected_negation = vec![
            OrdF32(f32::NEG_INFINITY)..OrdF32(0.0),
            OrdF32(3.0)..OrdF32(f32::INFINITY),
        ];
        assert_eq!(set.negation().intervals, expected_negation);
        assert_eq!(set.negation().negation(), set);

        for member in [0.0, 1.0, 2.0, 2.999].iter() {
            assert!(set.contains(&OrdF32(*member)));
        }
        let outsiders = [
            f32::NEG_INFINITY,
            f32::MIN,
            -1.0,
            -f32::EPSILON,
            3.0,
            f32::MAX,
            f32::INFINITY,
        ];
        for outsider in outsiders.iter() {
            assert!(!set.contains(&OrdF32(*outsider)));
        }
    }

    /// Same scenario as `i32`, using the `OrdF64` wrapper.
    #[test]
    fn f64() {
        let mut set = IntervalSet::new();
        set.union(OrdF64(0.0)..OrdF64(2.0));
        set.union(OrdF64(1.0)..OrdF64(3.0));
        assert_eq!(*set.measure(), 3.0);

        let expected_negation = vec![
            OrdF64(f64::NEG_INFINITY)..OrdF64(0.0),
            OrdF64(3.0)..OrdF64(f64::INFINITY),
        ];
        assert_eq!(set.negation().intervals, expected_negation);
        assert_eq!(set.negation().negation(), set);

        for member in [0.0, 1.0, 2.0, 2.999].iter() {
            assert!(set.contains(&OrdF64(*member)));
        }
        let outsiders = [
            f64::NEG_INFINITY,
            f64::MIN,
            -1.0,
            -f64::EPSILON,
            3.0,
            f64::MAX,
            f64::INFINITY,
        ];
        for outsider in outsiders.iter() {
            assert!(!set.contains(&OrdF64(*outsider)));
        }
    }
}