1use std::collections::HashSet;
2
3#[derive(Debug, Clone)]
21pub struct IntervalTree<T: Clone + Eq + std::hash::Hash> {
22 root: Option<Box<Node<T>>>,
23 size: usize,
24}
25
/// One stored interval together with its value set and subtree links.
#[derive(Debug, Clone)]
struct Node<T>
where
    T: Clone + Eq + std::hash::Hash,
{
    /// Inclusive lower bound; also the key the search tree is ordered by.
    low: u32,
    /// Inclusive upper bound.
    high: u32,
    /// Largest `high` found in the subtree rooted at this node (itself
    /// included); queries use it to skip subtrees that end too early.
    max_high: u32,
    /// Values attached to exactly this `(low, high)` interval.
    values: HashSet<T>,
    /// Subtree of intervals whose `low` is smaller than this node's `low`.
    left: Option<Box<Node<T>>>,
    /// Subtree of all other intervals (`low` greater than or equal to this
    /// node's `low`).
    right: Option<Box<Node<T>>>,
}
40
41impl<T: Clone + Eq + std::hash::Hash> IntervalTree<T> {
42 pub fn new() -> Self {
44 Self {
45 root: None,
46 size: 0,
47 }
48 }
49
50 pub fn insert(&mut self, low: u32, high: u32, value: T) {
52 if let Some(root) = &mut self.root {
53 if Self::insert_into_node(root, low, high, value) {
54 self.size += 1;
55 }
56 } else {
57 let mut values = HashSet::new();
58 values.insert(value);
59 self.root = Some(Box::new(Node {
60 low,
61 high,
62 max_high: high,
63 values,
64 left: None,
65 right: None,
66 }));
67 self.size = 1;
68 }
69 }
70
71 fn insert_into_node(node: &mut Box<Node<T>>, low: u32, high: u32, value: T) -> bool {
73 if high > node.max_high {
75 node.max_high = high;
76 }
77
78 if low == node.low && high == node.high {
80 node.values.insert(value);
82 return false; }
84
85 if low < node.low {
87 if let Some(left) = &mut node.left {
88 Self::insert_into_node(left, low, high, value)
89 } else {
90 let mut values = HashSet::new();
91 values.insert(value);
92 node.left = Some(Box::new(Node {
93 low,
94 high,
95 max_high: high,
96 values,
97 left: None,
98 right: None,
99 }));
100 true
101 }
102 } else if let Some(right) = &mut node.right {
103 Self::insert_into_node(right, low, high, value)
104 } else {
105 let mut values = HashSet::new();
106 values.insert(value);
107 node.right = Some(Box::new(Node {
108 low,
109 high,
110 max_high: high,
111 values,
112 left: None,
113 right: None,
114 }));
115 true
116 }
117 }
118
119 pub fn remove(&mut self, low: u32, high: u32, value: &T) -> bool {
121 if let Some(root) = &mut self.root {
122 Self::remove_from_node(root, low, high, value)
123 } else {
124 false
125 }
126 }
127
128 fn remove_from_node(node: &mut Box<Node<T>>, low: u32, high: u32, value: &T) -> bool {
129 if low == node.low && high == node.high {
130 return node.values.remove(value);
131 }
132
133 if low < node.low {
134 if let Some(left) = &mut node.left {
135 return Self::remove_from_node(left, low, high, value);
136 }
137 } else if let Some(right) = &mut node.right {
138 return Self::remove_from_node(right, low, high, value);
139 }
140
141 false
142 }
143
144 pub fn query(&self, query_low: u32, query_high: u32) -> Vec<(u32, u32, HashSet<T>)> {
146 let mut results = Vec::new();
147 if let Some(root) = &self.root {
148 Self::query_node(root, query_low, query_high, &mut results);
149 }
150 results
151 }
152
153 fn query_node(
154 node: &Node<T>,
155 query_low: u32,
156 query_high: u32,
157 results: &mut Vec<(u32, u32, HashSet<T>)>,
158 ) {
159 if node.low <= query_high && node.high >= query_low {
161 results.push((node.low, node.high, node.values.clone()));
162 }
163
164 if let Some(left) = &node.left {
166 if left.max_high >= query_low {
168 Self::query_node(left, query_low, query_high, results);
169 }
170 }
171
172 if let Some(right) = &node.right {
174 if query_high >= node.low {
176 Self::query_node(right, query_low, query_high, results);
177 }
178 }
179 }
180
181 pub fn get_mut(&mut self, low: u32, high: u32) -> Option<&mut HashSet<T>> {
183 if let Some(root) = &mut self.root {
184 Self::get_mut_in_node(root, low, high)
185 } else {
186 None
187 }
188 }
189
190 fn get_mut_in_node(node: &mut Box<Node<T>>, low: u32, high: u32) -> Option<&mut HashSet<T>> {
191 if low == node.low && high == node.high {
192 return Some(&mut node.values);
193 }
194
195 if low < node.low {
196 if let Some(left) = &mut node.left {
197 return Self::get_mut_in_node(left, low, high);
198 }
199 } else if let Some(right) = &mut node.right {
200 return Self::get_mut_in_node(right, low, high);
201 }
202
203 None
204 }
205
206 pub fn is_empty(&self) -> bool {
208 self.root.is_none()
209 }
210
211 pub fn len(&self) -> usize {
213 self.size
214 }
215
216 pub fn clear(&mut self) {
218 self.root = None;
219 self.size = 0;
220 }
221
222 pub fn entry(&mut self, low: u32, high: u32) -> Entry<'_, T> {
224 Entry {
225 tree: self,
226 low,
227 high,
228 }
229 }
230
231 pub fn bulk_build_points(&mut self, mut items: Vec<(u32, std::collections::HashSet<T>)>) {
234 if self.root.is_some() {
235 for (k, set) in items.into_iter() {
237 for v in set {
238 self.insert(k, k, v);
239 }
240 }
241 return;
242 }
243 if items.is_empty() {
244 return;
245 }
246 items.sort_by_key(|(k, _)| *k);
248 let mut dedup: Vec<(u32, std::collections::HashSet<T>)> = Vec::with_capacity(items.len());
250 for (k, set) in items.into_iter() {
251 if let Some(last) = dedup.last_mut() {
252 if last.0 == k {
253 last.1.extend(set);
254 continue;
255 }
256 }
257 dedup.push((k, set));
258 }
259 fn build_balanced<T: Clone + Eq + std::hash::Hash>(
260 slice: &[(u32, std::collections::HashSet<T>)],
261 ) -> Option<Box<Node<T>>> {
262 if slice.is_empty() {
263 return None;
264 }
265 let mid = slice.len() / 2;
266 let (low, values) = (&slice[mid].0, &slice[mid].1);
267 let left = build_balanced(&slice[..mid]);
268 let right = build_balanced(&slice[mid + 1..]);
269 let mut max_high = *low;
271 if let Some(ref l) = left {
272 if l.max_high > max_high {
273 max_high = l.max_high;
274 }
275 }
276 if let Some(ref r) = right {
277 if r.max_high > max_high {
278 max_high = r.max_high;
279 }
280 }
281 Some(Box::new(Node {
282 low: *low,
283 high: *low,
284 max_high,
285 values: values.clone(),
286 left,
287 right,
288 }))
289 }
290 self.size = dedup.len();
291 self.root = build_balanced(&dedup);
292 }
293}
294
295impl<T: Clone + Eq + std::hash::Hash> Default for IntervalTree<T> {
296 fn default() -> Self {
297 Self::new()
298 }
299}
300
301pub struct Entry<'a, T: Clone + Eq + std::hash::Hash> {
303 tree: &'a mut IntervalTree<T>,
304 low: u32,
305 high: u32,
306}
307
308impl<'a, T: Clone + Eq + std::hash::Hash> Entry<'a, T> {
309 pub fn or_insert_with<F>(self, f: F) -> &'a mut HashSet<T>
311 where
312 F: FnOnce() -> HashSet<T>,
313 {
314 if self.tree.get_mut(self.low, self.high).is_none() {
316 if let Some(root) = &mut self.tree.root {
318 Self::ensure_interval_exists(root, self.low, self.high);
319 } else {
320 self.tree.root = Some(Box::new(Node {
321 low: self.low,
322 high: self.high,
323 max_high: self.high,
324 values: f(),
325 left: None,
326 right: None,
327 }));
328 self.tree.size = 1;
329 }
330 }
331
332 self.tree.get_mut(self.low, self.high).unwrap()
333 }
334
335 fn ensure_interval_exists(node: &mut Box<Node<T>>, low: u32, high: u32) {
336 if high > node.max_high {
337 node.max_high = high;
338 }
339
340 if low == node.low && high == node.high {
341 return;
342 }
343
344 if low < node.low {
345 if let Some(left) = &mut node.left {
346 Self::ensure_interval_exists(left, low, high);
347 } else {
348 node.left = Some(Box::new(Node {
349 low,
350 high,
351 max_high: high,
352 values: HashSet::new(),
353 left: None,
354 right: None,
355 }));
356 }
357 } else if let Some(right) = &mut node.right {
358 Self::ensure_interval_exists(right, low, high);
359 } else {
360 node.right = Some(Box::new(Node {
361 low,
362 high,
363 max_high: high,
364 values: HashSet::new(),
365 left: None,
366 right: None,
367 }));
368 }
369 }
370}
371
#[cfg(test)]
mod tests {
    use super::*;

    /// A single point interval is returned with its bounds and value.
    #[test]
    fn test_insert_and_query_point_interval() {
        let mut tree = IntervalTree::new();
        tree.insert(5, 5, 100);

        let hits = tree.query(5, 5);
        assert_eq!(hits.len(), 1);

        let (low, high, values) = &hits[0];
        assert_eq!(*low, 5);
        assert_eq!(*high, 5);
        assert!(values.contains(&100));
    }

    /// Range queries report exactly the overlapping intervals.
    #[test]
    fn test_insert_and_query_range() {
        let mut tree = IntervalTree::new();
        for &(low, high, value) in &[(10, 20, 1), (15, 25, 2), (30, 40, 3)] {
            tree.insert(low, high, value);
        }

        // [12, 22] touches [10, 20] and [15, 25] but not [30, 40].
        let overlapping = tree.query(12, 22);
        assert_eq!(overlapping.len(), 2);

        // [35, 45] touches only [30, 40].
        let tail = tree.query(35, 45);
        assert_eq!(tail.len(), 1);
        assert!(tail[0].2.contains(&3));
    }

    /// Removing one value leaves the interval and its other values in place.
    #[test]
    fn test_remove_value() {
        let mut tree = IntervalTree::new();
        tree.insert(5, 5, 100);
        tree.insert(5, 5, 200);

        let before = tree.query(5, 5);
        assert_eq!(before.len(), 1);
        assert_eq!(before[0].2.len(), 2);

        tree.remove(5, 5, &100);

        let after = tree.query(5, 5);
        assert_eq!(after.len(), 1);
        let remaining = &after[0].2;
        assert_eq!(remaining.len(), 1);
        assert!(remaining.contains(&200));
    }

    /// Two `entry` calls for the same interval address the same set.
    #[test]
    fn test_entry_api() {
        let mut tree: IntervalTree<i32> = IntervalTree::new();

        for value in 42..44 {
            tree.entry(10, 10).or_insert_with(HashSet::new).insert(value);
        }

        let hits = tree.query(10, 10);
        assert_eq!(hits.len(), 1);

        let values = &hits[0].2;
        assert_eq!(values.len(), 2);
        assert!(values.contains(&42));
        assert!(values.contains(&43));
    }

    /// One hundred widely spaced points; the upper half is found by an
    /// open-ended query.
    #[test]
    fn test_large_sparse_tree() {
        let mut tree = IntervalTree::new();

        // Points 0, 10_000, ..., 990_000.
        for step in 0u32..100 {
            let point = step * 10_000;
            tree.insert(point, point, point as i32);
        }

        assert_eq!(tree.len(), 100);

        // 500_000 through 990_000 inclusive is 50 points.
        let upper_half = tree.query(500_000, u32::MAX);
        assert_eq!(upper_half.len(), 50);
    }
}