// formualizer_eval/engine/interval_tree.rs

use std::collections::{btree_map, BTreeMap, HashSet};

// Custom interval index optimized for spreadsheet cell indexing.
//
// ## Design assumptions
//
// 1. **Point intervals are the common case** - most cells are single points [r, r] or [c, c].
// 2. **Sparse data** - even million-row sheets typically hold fewer than ~10K cells.
// 3. **Batch updates** - during shifts, many intervals are updated at once.
// 4. **Small value sets** - each interval maps to a small set of values (such as `VertexId`s).
//
// ## Implementation
//
// Intervals live in a `BTreeMap` keyed by their low endpoint. Each key holds a `Vec` of
// `IntervalNode`s, one per distinct high endpoint, and each node owns the `HashSet` of
// values attached to that exact `[low, high]` interval.
//
// There is no max-endpoint augmentation: an overlap query walks every bucket whose low
// endpoint is <= the query's high bound and filters on each node's high endpoint.

/// One stored interval.
///
/// The low endpoint is not stored here: it is the `BTreeMap` key of the bucket that
/// holds this node, so only the high endpoint and the attached values live in the node.
#[derive(Debug, Clone)]
struct IntervalNode<T: Clone + Eq + std::hash::Hash> {
    /// Inclusive high endpoint of the interval.
    high: u32,
    /// Values attached to this exact interval.
    values: HashSet<T>,
}

27/// B-Tree based implementation of the interval index.
28#[derive(Debug, Clone)]
29pub struct IntervalTree<T: Clone + Eq + std::hash::Hash> {
30    /// Maps low coordinate to a set of intervals/values starting there.
31    /// Internal storage uses IntervalNode, NOT Entry.
32    map: BTreeMap<u32, Vec<IntervalNode<T>>>,
33    size: usize,
34}
35
36impl<T: Clone + Eq + std::hash::Hash> Default for IntervalTree<T> {
37    fn default() -> Self {
38        Self::new()
39    }
40}
41
42impl<T: Clone + Eq + std::hash::Hash> IntervalTree<T> {
43    pub fn new() -> Self {
44        Self {
45            map: BTreeMap::new(),
46            size: 0,
47        }
48    }
49
50    pub fn len(&self) -> usize {
51        self.size
52    }
53
54    pub fn is_empty(&self) -> bool {
55        self.size == 0
56    }
57
58    /// Get a mutable reference to the values for an exact interval match.
59    /// Required by the Entry API.
60    pub fn get_mut(&mut self, low: u32, high: u32) -> Option<&mut HashSet<T>> {
61        self.map.get_mut(&low).and_then(|nodes| {
62            nodes
63                .iter_mut()
64                .find(|n| n.high == high)
65                .map(|n| &mut n.values)
66        })
67    }
68
69    /// Insert a value for the given interval [low, high]
70    pub fn insert(&mut self, low: u32, high: u32, value: T) {
71        let entries = self.map.entry(low).or_default();
72
73        if let Some(node) = entries.iter_mut().find(|n| n.high == high) {
74            node.values.insert(value);
75        } else {
76            let mut values = HashSet::new();
77            values.insert(value);
78            entries.push(IntervalNode { high, values });
79            self.size += 1;
80        }
81    }
82
83    pub fn query(&self, q_low: u32, q_high: u32) -> Vec<(u32, u32, HashSet<T>)> {
84        let mut results = Vec::new();
85        for (&low, nodes) in self.map.range(..=q_high) {
86            for node in nodes {
87                if node.high >= q_low {
88                    results.push((low, node.high, node.values.clone()));
89                }
90            }
91        }
92        results
93    }
94
95    pub fn remove(&mut self, low: u32, high: u32, value: &T) -> bool {
96        if let Some(nodes) = self.map.get_mut(&low)
97            && let Some(node) = nodes.iter_mut().find(|n| n.high == high)
98        {
99            let removed = node.values.remove(value);
100
101            if removed && node.values.is_empty() {
102                nodes.retain(|n| n.high != high);
103                self.size -= 1;
104                if nodes.is_empty() {
105                    self.map.remove(&low);
106                }
107            }
108            return removed;
109        }
110        false
111    }
112
113    pub fn entry(&mut self, low: u32, high: u32) -> BTreeEntry<'_, T> {
114        BTreeEntry {
115            tree: self,
116            low,
117            high,
118        }
119    }
120
121    /// Bulk build optimization for a collection of point intervals [x,x].
122    pub fn bulk_build_points(&mut self, mut items: Vec<(u32, HashSet<T>)>) {
123        if !self.is_empty() {
124            // Fallback: incremental insert to preserve existing nodes
125            for (coord, set) in items {
126                for val in set {
127                    self.insert(coord, coord, val);
128                }
129            }
130            return;
131        }
132
133        if items.is_empty() {
134            return;
135        }
136
137        // 1. Sort by coordinate
138        items.sort_by_key(|(k, _)| *k);
139
140        // 2. Process items. BTreeMap handles the balancing (O(log N)).
141        for (coord, set) in items {
142            let entries = self.map.entry(coord).or_default();
143
144            // Since this is specifically for point intervals, check if [coord, coord] exists
145            if let Some(node) = entries.iter_mut().find(|n| n.high == coord) {
146                node.values.extend(set);
147            } else {
148                entries.push(IntervalNode {
149                    high: coord,
150                    values: set,
151                });
152                self.size += 1;
153            }
154        }
155    }
156}
157
158pub struct BTreeEntry<'a, T: Clone + Eq + std::hash::Hash> {
159    tree: &'a mut IntervalTree<T>,
160    low: u32,
161    high: u32,
162}
163
164impl<'a, T: Clone + Eq + std::hash::Hash> BTreeEntry<'a, T> {
165    pub fn or_insert_with<F>(self, f: F) -> &'a mut HashSet<T>
166    where
167        F: FnOnce() -> HashSet<T>,
168    {
169        if self.tree.get_mut(self.low, self.high).is_none() {
170            let values = f();
171            let entries = self.tree.map.entry(self.low).or_default();
172            entries.push(IntervalNode {
173                high: self.high,
174                values,
175            });
176            self.tree.size += 1;
177        }
178        self.tree.get_mut(self.low, self.high).unwrap()
179    }
180}
181
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_insert_and_query_point_interval() {
        let mut tree = IntervalTree::new();
        tree.insert(5, 5, 100);

        let results = tree.query(5, 5);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0, 5);
        assert_eq!(results[0].1, 5);
        assert!(results[0].2.contains(&100));
    }

    #[test]
    fn test_insert_and_query_range() {
        let mut tree = IntervalTree::new();
        tree.insert(10, 20, 1);
        tree.insert(15, 25, 2);
        tree.insert(30, 40, 3);

        // [12, 22] overlaps [10, 20] and [15, 25] but not [30, 40].
        let results = tree.query(12, 22);
        assert_eq!(results.len(), 2);

        // [35, 45] overlaps only [30, 40].
        let results = tree.query(35, 45);
        assert_eq!(results.len(), 1);
        assert!(results[0].2.contains(&3));
    }

    #[test]
    fn test_remove_value() {
        let mut tree = IntervalTree::new();
        tree.insert(5, 5, 100);
        tree.insert(5, 5, 200);

        assert_eq!(tree.query(5, 5).len(), 1);
        assert_eq!(tree.query(5, 5)[0].2.len(), 2);

        assert!(tree.remove(5, 5, &100));

        let results = tree.query(5, 5);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].2.len(), 1);
        assert!(results[0].2.contains(&200));
    }

    #[test]
    fn test_remove_last_value_drops_interval() {
        let mut tree = IntervalTree::new();
        tree.insert(3, 7, 1);
        tree.insert(3, 9, 2);

        assert!(tree.remove(3, 7, &1));
        assert_eq!(tree.len(), 1);
        assert!(tree.get_mut(3, 7).is_none());

        // The sibling interval sharing the same low endpoint is untouched.
        let expected: HashSet<i32> = [2].iter().cloned().collect();
        assert_eq!(tree.query(0, u32::MAX), vec![(3, 9, expected)]);

        // Removing an already-removed value reports `false`.
        assert!(!tree.remove(3, 7, &1));
    }

    #[test]
    fn test_entry_api() {
        let mut tree: IntervalTree<i32> = IntervalTree::new();

        tree.entry(10, 10).or_insert_with(HashSet::new).insert(42);
        // A second entry() on the same interval must reuse the existing set.
        tree.entry(10, 10).or_insert_with(HashSet::new).insert(43);

        let results = tree.query(10, 10);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].2.len(), 2);
        assert!(results[0].2.contains(&42));
        assert!(results[0].2.contains(&43));
    }

    #[test]
    fn test_large_sparse_tree() {
        let mut tree = IntervalTree::new();

        // Simulate a sparse sheet: 100 point intervals, one every 10_000 rows.
        for i in (0..1_000_000).step_by(10000) {
            tree.insert(i, i, i as i32);
        }

        assert_eq!(tree.len(), 100);

        // Rows 500_000..=990_000 are the upper half of the points.
        let results = tree.query(500_000, u32::MAX);
        assert_eq!(results.len(), 50);
    }

    /// Regression guard: thousands of sequential `entry()` calls must each create a
    /// distinct interval. The test name refers to a bug in an earlier implementation;
    /// the current `BTreeMap`-backed index does not recurse.
    #[test]
    fn test_entry_recursion_bug() {
        let mut tree: IntervalTree<u32> = IntervalTree::new();

        let count: u32 = 5000;
        for i in 0..count {
            tree.entry(i, i).or_insert_with(HashSet::new);
        }

        assert_eq!(tree.len(), count as usize);
    }

    #[test]
    fn test_complex_overlaps() {
        let mut tree = IntervalTree::new();
        // Nested intervals
        tree.insert(10, 100, "A");
        tree.insert(20, 50, "B");
        tree.insert(30, 40, "C");

        // Partially overlapping
        tree.insert(5, 15, "D");
        tree.insert(95, 105, "E");

        // The very middle is covered by all three nested intervals.
        let results = tree.query(35, 35);
        assert_eq!(results.len(), 3); // A, B and C

        // Only the tail of the large interval and the head of the last one reach here.
        let results = tree.query(98, 102);
        assert_eq!(results.len(), 2); // A and E
    }

    #[test]
    fn test_multiple_values_and_size() {
        let mut tree = IntervalTree::new();

        // Same interval twice with different values: still one interval.
        tree.insert(10, 10, "val1");
        tree.insert(10, 10, "val2");
        assert_eq!(tree.len(), 1);

        // Same value twice: the HashSet absorbs the duplicate.
        tree.insert(10, 10, "val1");
        assert_eq!(tree.len(), 1);
        let results = tree.query(10, 10);
        assert_eq!(results[0].2.len(), 2);
    }

    #[test]
    fn test_remove_edge_cases() {
        let mut tree = IntervalTree::new();
        tree.insert(10, 20, "A");

        // Removing a value that isn't there leaves the interval intact.
        let removed = tree.remove(10, 20, &"B");
        assert!(!removed);
        assert_eq!(tree.query(10, 20)[0].2.len(), 1);

        // Removing from an interval that doesn't exist is a no-op.
        let removed = tree.remove(99, 100, &"A");
        assert!(!removed);
    }

    #[test]
    fn test_bulk_build_consistency() {
        let mut incremental_tree = IntervalTree::new();
        let mut bulk_tree = IntervalTree::new();

        let data: Vec<(u32, HashSet<&str>)> = vec![
            (10, vec!["A", "B"].into_iter().collect()),
            (20, vec!["C"].into_iter().collect()),
            (5, vec!["D"].into_iter().collect()),
        ];

        // Build incrementally
        for (coord, values) in &data {
            for val in values {
                incremental_tree.insert(*coord, *coord, *val);
            }
        }

        // Build using bulk
        bulk_tree.bulk_build_points(data.clone());

        // Both trees must be indistinguishable through the public API.
        assert_eq!(incremental_tree.len(), bulk_tree.len());
        assert_eq!(incremental_tree.query(0, 100), bulk_tree.query(0, 100));
    }

    #[test]
    fn test_bulk_build_into_non_empty_tree() {
        let mut tree = IntervalTree::new();
        tree.insert(10, 10, "X");
        tree.insert(10, 20, "Y");

        tree.bulk_build_points(vec![
            (10, vec!["A"].into_iter().collect()),
            (15, vec!["B"].into_iter().collect()),
        ]);

        // [10, 10] gains "A"; [15, 15] is new; [10, 20] is untouched.
        assert_eq!(tree.len(), 3);
        let results = tree.query(10, 10);
        assert_eq!(results.len(), 2);
        let point: HashSet<&str> = vec!["X", "A"].into_iter().collect();
        let span: HashSet<&str> = vec!["Y"].into_iter().collect();
        assert_eq!(results[0], (10, 10, point));
        assert_eq!(results[1], (10, 20, span));
    }

    #[test]
    fn test_query_stack_safety() {
        let mut tree = IntervalTree::new();
        let count = 10_000;

        // Many consecutive keys; query() iterates the map, so depth is not a concern,
        // but keep this as a regression guard.
        for i in 0..count {
            tree.insert(i, i, i);
        }

        // Query the very last point.
        let results = tree.query(count - 1, count - 1);
        assert_eq!(results.len(), 1);
    }

    #[test]
    fn test_empty_and_boundaries() {
        let mut tree: IntervalTree<i32> = IntervalTree::new();

        assert!(tree.is_empty());
        assert_eq!(tree.query(0, 100).len(), 0);
        assert!(!tree.remove(0, 0, &1));

        // Queries just outside [50, 60] on either side miss it.
        tree.insert(50, 60, 1);
        assert_eq!(tree.query(0, 49).len(), 0);
        assert_eq!(tree.query(61, 100).len(), 0);
    }

    #[test]
    fn test_multi_value_interval_size_tracking() {
        let mut tree = IntervalTree::new();
        let iv = (10, 20);

        // 1. Two values for the same interval count as one interval.
        tree.insert(iv.0, iv.1, "A");
        tree.insert(iv.0, iv.1, "B");
        assert_eq!(tree.len(), 1, "Should be 1 unique interval");

        // 2. Removing one value keeps the interval.
        assert!(tree.remove(iv.0, iv.1, &"A"));
        assert_eq!(
            tree.len(),
            1,
            "Should still be 1 interval after partial removal"
        );

        // 3. Removing the last value drops the interval.
        assert!(tree.remove(iv.0, iv.1, &"B"));
        assert_eq!(tree.len(), 0, "Should be 0 after last value removed");
    }
}
404        // 3. Remove second value - size should now be 0
405        assert!(tree.remove(iv.0, iv.1, &"B"));
406        assert_eq!(tree.len(), 0, "Should be 0 after last value removed");
407    }
408}