1use std::collections::{BTreeMap, HashSet};
2
/// One stored interval: its inclusive upper bound plus every value attached to it.
///
/// The lower bound is not stored on the node; it is the key of the `IntervalTree`
/// bucket that owns this node.
#[derive(Debug, Clone)]
struct IntervalNode<T>
where
    T: Clone + Eq + std::hash::Hash,
{
    /// Inclusive upper bound of the interval.
    high: u32,
    /// Values attached to this exact `(low, high)` interval.
    values: HashSet<T>,
}
26
27#[derive(Debug, Clone)]
29pub struct IntervalTree<T: Clone + Eq + std::hash::Hash> {
30 map: BTreeMap<u32, Vec<IntervalNode<T>>>,
33 size: usize,
34}
35
36impl<T: Clone + Eq + std::hash::Hash> Default for IntervalTree<T> {
37 fn default() -> Self {
38 Self::new()
39 }
40}
41
42impl<T: Clone + Eq + std::hash::Hash> IntervalTree<T> {
43 pub fn new() -> Self {
44 Self {
45 map: BTreeMap::new(),
46 size: 0,
47 }
48 }
49
50 pub fn len(&self) -> usize {
51 self.size
52 }
53
54 pub fn is_empty(&self) -> bool {
55 self.size == 0
56 }
57
58 pub fn get_mut(&mut self, low: u32, high: u32) -> Option<&mut HashSet<T>> {
61 self.map.get_mut(&low).and_then(|nodes| {
62 nodes
63 .iter_mut()
64 .find(|n| n.high == high)
65 .map(|n| &mut n.values)
66 })
67 }
68
69 pub fn insert(&mut self, low: u32, high: u32, value: T) {
71 let entries = self.map.entry(low).or_default();
72
73 if let Some(node) = entries.iter_mut().find(|n| n.high == high) {
74 node.values.insert(value);
75 } else {
76 let mut values = HashSet::new();
77 values.insert(value);
78 entries.push(IntervalNode { high, values });
79 self.size += 1;
80 }
81 }
82
83 pub fn query(&self, q_low: u32, q_high: u32) -> Vec<(u32, u32, HashSet<T>)> {
84 let mut results = Vec::new();
85 for (&low, nodes) in self.map.range(..=q_high) {
86 for node in nodes {
87 if node.high >= q_low {
88 results.push((low, node.high, node.values.clone()));
89 }
90 }
91 }
92 results
93 }
94
95 pub fn remove(&mut self, low: u32, high: u32, value: &T) -> bool {
96 if let Some(nodes) = self.map.get_mut(&low)
97 && let Some(node) = nodes.iter_mut().find(|n| n.high == high)
98 {
99 let removed = node.values.remove(value);
100
101 if removed && node.values.is_empty() {
102 nodes.retain(|n| n.high != high);
103 self.size -= 1;
104 if nodes.is_empty() {
105 self.map.remove(&low);
106 }
107 }
108 return removed;
109 }
110 false
111 }
112
113 pub fn entry(&mut self, low: u32, high: u32) -> BTreeEntry<'_, T> {
114 BTreeEntry {
115 tree: self,
116 low,
117 high,
118 }
119 }
120
121 pub fn bulk_build_points(&mut self, mut items: Vec<(u32, HashSet<T>)>) {
123 if !self.is_empty() {
124 for (coord, set) in items {
126 for val in set {
127 self.insert(coord, coord, val);
128 }
129 }
130 return;
131 }
132
133 if items.is_empty() {
134 return;
135 }
136
137 items.sort_by_key(|(k, _)| *k);
139
140 for (coord, set) in items {
142 let entries = self.map.entry(coord).or_default();
143
144 if let Some(node) = entries.iter_mut().find(|n| n.high == coord) {
146 node.values.extend(set);
147 } else {
148 entries.push(IntervalNode {
149 high: coord,
150 values: set,
151 });
152 self.size += 1;
153 }
154 }
155 }
156}
157
158pub struct BTreeEntry<'a, T: Clone + Eq + std::hash::Hash> {
159 tree: &'a mut IntervalTree<T>,
160 low: u32,
161 high: u32,
162}
163
164impl<'a, T: Clone + Eq + std::hash::Hash> BTreeEntry<'a, T> {
165 pub fn or_insert_with<F>(self, f: F) -> &'a mut HashSet<T>
166 where
167 F: FnOnce() -> HashSet<T>,
168 {
169 if self.tree.get_mut(self.low, self.high).is_none() {
170 let values = f();
171 let entries = self.tree.map.entry(self.low).or_default();
172 entries.push(IntervalNode {
173 high: self.high,
174 values,
175 });
176 self.tree.size += 1;
177 }
178 self.tree.get_mut(self.low, self.high).unwrap()
179 }
180}
181
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_insert_and_query_point_interval() {
        let mut tree = IntervalTree::new();
        tree.insert(5, 5, 100);

        let results = tree.query(5, 5);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0, 5);
        assert_eq!(results[0].1, 5);
        assert!(results[0].2.contains(&100));
    }

    #[test]
    fn test_insert_and_query_range() {
        let mut tree = IntervalTree::new();
        tree.insert(10, 20, 1);
        tree.insert(15, 25, 2);
        tree.insert(30, 40, 3);

        // [12, 22] overlaps [10, 20] and [15, 25], but not [30, 40].
        let results = tree.query(12, 22);
        assert_eq!(results.len(), 2);

        // [35, 45] overlaps only [30, 40].
        let results = tree.query(35, 45);
        assert_eq!(results.len(), 1);
        assert!(results[0].2.contains(&3));
    }

    #[test]
    fn test_remove_value() {
        let mut tree = IntervalTree::new();
        tree.insert(5, 5, 100);
        tree.insert(5, 5, 200);

        assert_eq!(tree.query(5, 5).len(), 1);
        assert_eq!(tree.query(5, 5)[0].2.len(), 2);

        tree.remove(5, 5, &100);

        let results = tree.query(5, 5);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].2.len(), 1);
        assert!(results[0].2.contains(&200));
    }

    #[test]
    fn test_entry_api() {
        let mut tree: IntervalTree<i32> = IntervalTree::new();

        tree.entry(10, 10).or_insert_with(HashSet::new).insert(42);
        // The second lookup must reuse the interval created by the first.
        tree.entry(10, 10).or_insert_with(HashSet::new).insert(43);

        let results = tree.query(10, 10);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].2.len(), 2);
        assert!(results[0].2.contains(&42));
        assert!(results[0].2.contains(&43));
    }

    #[test]
    fn test_entry_does_not_call_closure_for_existing_interval() {
        let mut tree: IntervalTree<i32> = IntervalTree::new();
        tree.insert(3, 8, 1);

        let mut calls = 0;
        tree.entry(3, 8)
            .or_insert_with(|| {
                calls += 1;
                HashSet::new()
            })
            .insert(2);

        assert_eq!(calls, 0);
        assert_eq!(tree.len(), 1);
        assert_eq!(tree.query(3, 8)[0].2.len(), 2);
    }

    #[test]
    fn test_large_sparse_tree() {
        let mut tree = IntervalTree::new();

        // 100 point intervals spaced 10_000 apart.
        for i in (0..1_000_000).step_by(10000) {
            tree.insert(i, i, i as i32);
        }

        assert_eq!(tree.len(), 100);

        // Points 500_000..=990_000 lie in the upper half.
        let results = tree.query(500_000, u32::MAX);
        assert_eq!(results.len(), 50);
    }

    #[test]
    fn test_entry_recursion_bug() {
        let mut tree: IntervalTree<u32> = IntervalTree::new();

        // Every entry-created interval counts, even though its value set stays empty.
        let count: u32 = 5000;
        for i in 0..count {
            tree.entry(i, i).or_insert_with(HashSet::new);
        }

        assert_eq!(tree.len(), count as usize);
    }

    #[test]
    fn test_complex_overlaps() {
        let mut tree = IntervalTree::new();
        // Three nested intervals.
        tree.insert(10, 100, "A");
        tree.insert(20, 50, "B");
        tree.insert(30, 40, "C");
        // Intervals straddling the outer edges of "A".
        tree.insert(5, 15, "D");
        tree.insert(95, 105, "E");

        // 35 lies inside A, B and C only.
        let results = tree.query(35, 35);
        assert_eq!(results.len(), 3);

        // [98, 102] touches A and E only.
        let results = tree.query(98, 102);
        assert_eq!(results.len(), 2);
    }

    #[test]
    fn test_multiple_values_and_size() {
        let mut tree = IntervalTree::new();

        tree.insert(10, 10, "val1");
        tree.insert(10, 10, "val2");
        assert_eq!(tree.len(), 1);

        // Re-inserting an existing value changes neither the interval count nor the set.
        tree.insert(10, 10, "val1");
        assert_eq!(tree.len(), 1);

        let results = tree.query(10, 10);
        assert_eq!(results[0].2.len(), 2);
    }

    #[test]
    fn test_remove_edge_cases() {
        let mut tree = IntervalTree::new();
        tree.insert(10, 20, "A");

        // Removing a value that was never attached is a no-op.
        let removed = tree.remove(10, 20, &"B");
        assert!(!removed);
        assert_eq!(tree.query(10, 20)[0].2.len(), 1);

        // Removing from an interval that does not exist is a no-op.
        let removed = tree.remove(99, 100, &"A");
        assert!(!removed);
    }

    #[test]
    fn test_bulk_build_consistency() {
        let mut incremental_tree = IntervalTree::new();
        let mut bulk_tree = IntervalTree::new();

        let data: Vec<(u32, HashSet<&str>)> = vec![
            (10, vec!["A", "B"].into_iter().collect()),
            (20, vec!["C"].into_iter().collect()),
            (5, vec!["D"].into_iter().collect()),
        ];

        for (coord, values) in &data {
            for val in values {
                incremental_tree.insert(*coord, *coord, *val);
            }
        }

        bulk_tree.bulk_build_points(data.clone());

        assert_eq!(incremental_tree.len(), bulk_tree.len());
        assert_eq!(incremental_tree.query(0, 100), bulk_tree.query(0, 100));
    }

    #[test]
    fn test_bulk_build_into_non_empty_tree() {
        let mut tree = IntervalTree::new();
        tree.insert(10, 10, "A");
        tree.insert(5, 15, "X");

        tree.bulk_build_points(vec![
            (10, vec!["B"].into_iter().collect()),
            (20, vec!["C"].into_iter().collect()),
        ]);

        // The point at 10 is merged; the point at 20 is new.
        assert_eq!(tree.len(), 3);
        let results = tree.query(10, 10);
        assert_eq!(results.len(), 2);
        assert_eq!((results[0].0, results[0].1), (5, 15));
        assert_eq!((results[1].0, results[1].1), (10, 10));
        assert!(results[1].2.contains("A") && results[1].2.contains("B"));
    }

    #[test]
    fn test_query_stack_safety() {
        let mut tree = IntervalTree::new();
        let count = 10_000;

        for i in 0..count {
            tree.insert(i, i, i);
        }

        // Only the last point matches, even though every interval is scanned.
        let results = tree.query(count - 1, count - 1);
        assert_eq!(results.len(), 1);
    }

    #[test]
    fn test_empty_and_boundaries() {
        let mut tree: IntervalTree<i32> = IntervalTree::new();

        assert!(tree.is_empty());
        assert_eq!(tree.query(0, 100).len(), 0);
        assert!(!tree.remove(0, 0, &1));

        // Queries ending just before or starting just after [50, 60] must miss it.
        tree.insert(50, 60, 1);
        assert_eq!(tree.query(0, 49).len(), 0);
        assert_eq!(tree.query(61, 100).len(), 0);
    }

    #[test]
    fn test_multi_value_interval_size_tracking() {
        let mut tree = IntervalTree::new();
        let iv = (10, 20);

        tree.insert(iv.0, iv.1, "A");
        tree.insert(iv.0, iv.1, "B");
        assert_eq!(tree.len(), 1, "Should be 1 unique interval");

        assert!(tree.remove(iv.0, iv.1, &"A"));
        assert_eq!(
            tree.len(),
            1,
            "Should still be 1 interval after partial removal"
        );

        assert!(tree.remove(iv.0, iv.1, &"B"));
        assert_eq!(tree.len(), 0, "Should be 0 after last value removed");
    }
}