formualizer_eval/engine/
interval_tree.rs

1use std::collections::HashSet;
2
3/// Custom interval tree optimized for spreadsheet cell indexing.
4///
5/// ## Design decisions:
6///
7/// 1. **Point intervals are the common case** - Most cells are single points [r,r] or [c,c]
8/// 2. **Sparse data** - Even million-row sheets typically have <10K cells
9/// 3. **Batch updates** - During shifts, we update many intervals at once
10/// 4. **Small value sets** - Each interval maps to a small set of VertexIds
11///
12/// ## Implementation:
13///
14/// Uses an augmented BST where each node stores:
15/// - Interval [low, high]
16/// - Max endpoint in subtree (for efficient pruning)
17/// - Value set (HashSet<VertexId>)
18///
19/// This is simpler than generic interval trees because we optimize for our specific use case.
20#[derive(Debug, Clone)]
21pub struct IntervalTree<T: Clone + Eq + std::hash::Hash> {
22    root: Option<Box<Node<T>>>,
23    size: usize,
24}
25
/// One stored interval together with the links to its two subtrees.
#[derive(Debug, Clone)]
struct Node<T>
where
    T: Clone + Eq + std::hash::Hash,
{
    /// Lower endpoint of the interval `[low, high]`; also the search key.
    low: u32,
    /// Upper endpoint of the interval `[low, high]`.
    high: u32,
    /// Largest `high` of any interval in the subtree rooted here (lets queries prune).
    max_high: u32,
    /// Values attached to this exact interval.
    values: HashSet<T>,
    /// Subtree of intervals whose `low` is smaller than this node's.
    left: Option<Box<Node<T>>>,
    /// Subtree of the remaining intervals (`low` not smaller than this node's).
    right: Option<Box<Node<T>>>,
}
40
41impl<T: Clone + Eq + std::hash::Hash> IntervalTree<T> {
42    /// Create a new empty interval tree
43    pub fn new() -> Self {
44        Self {
45            root: None,
46            size: 0,
47        }
48    }
49
50    /// Insert a value for the given interval [low, high]
51    pub fn insert(&mut self, low: u32, high: u32, value: T) {
52        if let Some(root) = &mut self.root {
53            if Self::insert_into_node(root, low, high, value) {
54                self.size += 1;
55            }
56        } else {
57            let mut values = HashSet::new();
58            values.insert(value);
59            self.root = Some(Box::new(Node {
60                low,
61                high,
62                max_high: high,
63                values,
64                left: None,
65                right: None,
66            }));
67            self.size = 1;
68        }
69    }
70
71    /// Insert into a node, returns true if a new interval was created
72    fn insert_into_node(node: &mut Box<Node<T>>, low: u32, high: u32, value: T) -> bool {
73        // Update max_high if needed
74        if high > node.max_high {
75            node.max_high = high;
76        }
77
78        // Check if this is the same interval
79        if low == node.low && high == node.high {
80            // Add value to existing interval
81            node.values.insert(value);
82            return false; // No new interval created
83        }
84
85        // Decide which subtree to insert into based on low value
86        if low < node.low {
87            if let Some(left) = &mut node.left {
88                Self::insert_into_node(left, low, high, value)
89            } else {
90                let mut values = HashSet::new();
91                values.insert(value);
92                node.left = Some(Box::new(Node {
93                    low,
94                    high,
95                    max_high: high,
96                    values,
97                    left: None,
98                    right: None,
99                }));
100                true
101            }
102        } else if let Some(right) = &mut node.right {
103            Self::insert_into_node(right, low, high, value)
104        } else {
105            let mut values = HashSet::new();
106            values.insert(value);
107            node.right = Some(Box::new(Node {
108                low,
109                high,
110                max_high: high,
111                values,
112                left: None,
113                right: None,
114            }));
115            true
116        }
117    }
118
119    /// Remove a value from the interval [low, high]
120    pub fn remove(&mut self, low: u32, high: u32, value: &T) -> bool {
121        if let Some(root) = &mut self.root {
122            Self::remove_from_node(root, low, high, value)
123        } else {
124            false
125        }
126    }
127
128    fn remove_from_node(node: &mut Box<Node<T>>, low: u32, high: u32, value: &T) -> bool {
129        if low == node.low && high == node.high {
130            return node.values.remove(value);
131        }
132
133        if low < node.low {
134            if let Some(left) = &mut node.left {
135                return Self::remove_from_node(left, low, high, value);
136            }
137        } else if let Some(right) = &mut node.right {
138            return Self::remove_from_node(right, low, high, value);
139        }
140
141        false
142    }
143
144    /// Query all intervals that overlap with [query_low, query_high]
145    pub fn query(&self, query_low: u32, query_high: u32) -> Vec<(u32, u32, HashSet<T>)> {
146        let mut results = Vec::new();
147        if let Some(root) = &self.root {
148            Self::query_node(root, query_low, query_high, &mut results);
149        }
150        results
151    }
152
153    fn query_node(
154        node: &Node<T>,
155        query_low: u32,
156        query_high: u32,
157        results: &mut Vec<(u32, u32, HashSet<T>)>,
158    ) {
159        // Check if this node's interval overlaps with query
160        if node.low <= query_high && node.high >= query_low {
161            results.push((node.low, node.high, node.values.clone()));
162        }
163
164        // Check left subtree if it might contain overlapping intervals
165        if let Some(left) = &node.left {
166            // Only traverse left if its max_high could overlap
167            if left.max_high >= query_low {
168                Self::query_node(left, query_low, query_high, results);
169            }
170        }
171
172        // Check right subtree if it might contain overlapping intervals
173        if let Some(right) = &node.right {
174            // Only traverse right if the query extends beyond this node's low
175            if query_high >= node.low {
176                Self::query_node(right, query_low, query_high, results);
177            }
178        }
179    }
180
181    /// Get mutable reference to values for an exact interval match
182    pub fn get_mut(&mut self, low: u32, high: u32) -> Option<&mut HashSet<T>> {
183        if let Some(root) = &mut self.root {
184            Self::get_mut_in_node(root, low, high)
185        } else {
186            None
187        }
188    }
189
190    fn get_mut_in_node(node: &mut Box<Node<T>>, low: u32, high: u32) -> Option<&mut HashSet<T>> {
191        if low == node.low && high == node.high {
192            return Some(&mut node.values);
193        }
194
195        if low < node.low {
196            if let Some(left) = &mut node.left {
197                return Self::get_mut_in_node(left, low, high);
198            }
199        } else if let Some(right) = &mut node.right {
200            return Self::get_mut_in_node(right, low, high);
201        }
202
203        None
204    }
205
206    /// Check if the tree is empty
207    pub fn is_empty(&self) -> bool {
208        self.root.is_none()
209    }
210
211    /// Get the number of intervals in the tree
212    pub fn len(&self) -> usize {
213        self.size
214    }
215
216    /// Clear all intervals from the tree
217    pub fn clear(&mut self) {
218        self.root = None;
219        self.size = 0;
220    }
221
222    /// Entry API for convenient insert-or-update operations
223    pub fn entry(&mut self, low: u32, high: u32) -> Entry<'_, T> {
224        Entry {
225            tree: self,
226            low,
227            high,
228        }
229    }
230
231    /// Bulk build optimization for a collection of point intervals [x,x].
232    /// Expects (low == high) for all items. Existing content is discarded if tree is empty; if not empty, falls back to incremental inserts.
233    pub fn bulk_build_points(&mut self, mut items: Vec<(u32, std::collections::HashSet<T>)>) {
234        if self.root.is_some() {
235            // Fallback: incremental insert to preserve existing nodes
236            for (k, set) in items.into_iter() {
237                for v in set {
238                    self.insert(k, k, v);
239                }
240            }
241            return;
242        }
243        if items.is_empty() {
244            return;
245        }
246        // Sort by coordinate to build balanced tree
247        items.sort_by_key(|(k, _)| *k);
248        // Deduplicate keys by merging sets
249        let mut dedup: Vec<(u32, std::collections::HashSet<T>)> = Vec::with_capacity(items.len());
250        for (k, set) in items.into_iter() {
251            if let Some(last) = dedup.last_mut() {
252                if last.0 == k {
253                    last.1.extend(set);
254                    continue;
255                }
256            }
257            dedup.push((k, set));
258        }
259        fn build_balanced<T: Clone + Eq + std::hash::Hash>(
260            slice: &[(u32, std::collections::HashSet<T>)],
261        ) -> Option<Box<Node<T>>> {
262            if slice.is_empty() {
263                return None;
264            }
265            let mid = slice.len() / 2;
266            let (low, values) = (&slice[mid].0, &slice[mid].1);
267            let left = build_balanced(&slice[..mid]);
268            let right = build_balanced(&slice[mid + 1..]);
269            // max_high is same as low (point interval); but need subtree max
270            let mut max_high = *low;
271            if let Some(ref l) = left {
272                if l.max_high > max_high {
273                    max_high = l.max_high;
274                }
275            }
276            if let Some(ref r) = right {
277                if r.max_high > max_high {
278                    max_high = r.max_high;
279                }
280            }
281            Some(Box::new(Node {
282                low: *low,
283                high: *low,
284                max_high,
285                values: values.clone(),
286                left,
287                right,
288            }))
289        }
290        self.size = dedup.len();
291        self.root = build_balanced(&dedup);
292    }
293}
294
295impl<T: Clone + Eq + std::hash::Hash> Default for IntervalTree<T> {
296    fn default() -> Self {
297        Self::new()
298    }
299}
300
301/// Entry API for interval tree
302pub struct Entry<'a, T: Clone + Eq + std::hash::Hash> {
303    tree: &'a mut IntervalTree<T>,
304    low: u32,
305    high: u32,
306}
307
308impl<'a, T: Clone + Eq + std::hash::Hash> Entry<'a, T> {
309    /// Get or insert an empty HashSet for this interval
310    pub fn or_insert_with<F>(self, f: F) -> &'a mut HashSet<T>
311    where
312        F: FnOnce() -> HashSet<T>,
313    {
314        // Check if interval exists
315        if self.tree.get_mut(self.low, self.high).is_none() {
316            // Create new node with empty set
317            if let Some(root) = &mut self.tree.root {
318                Self::ensure_interval_exists(root, self.low, self.high);
319            } else {
320                self.tree.root = Some(Box::new(Node {
321                    low: self.low,
322                    high: self.high,
323                    max_high: self.high,
324                    values: f(),
325                    left: None,
326                    right: None,
327                }));
328                self.tree.size = 1;
329            }
330        }
331
332        self.tree.get_mut(self.low, self.high).unwrap()
333    }
334
335    fn ensure_interval_exists(node: &mut Box<Node<T>>, low: u32, high: u32) {
336        if high > node.max_high {
337            node.max_high = high;
338        }
339
340        if low == node.low && high == node.high {
341            return;
342        }
343
344        if low < node.low {
345            if let Some(left) = &mut node.left {
346                Self::ensure_interval_exists(left, low, high);
347            } else {
348                node.left = Some(Box::new(Node {
349                    low,
350                    high,
351                    max_high: high,
352                    values: HashSet::new(),
353                    left: None,
354                    right: None,
355                }));
356            }
357        } else if let Some(right) = &mut node.right {
358            Self::ensure_interval_exists(right, low, high);
359        } else {
360            node.right = Some(Box::new(Node {
361                low,
362                high,
363                max_high: high,
364                values: HashSet::new(),
365                left: None,
366                right: None,
367            }));
368        }
369    }
370}
371
#[cfg(test)]
mod tests {
    use super::*;

    /// A single point interval is found again by querying the same point.
    #[test]
    fn test_insert_and_query_point_interval() {
        let mut index = IntervalTree::new();
        index.insert(5, 5, 100);

        let hits = index.query(5, 5);
        assert_eq!(hits.len(), 1);

        let (low, high, values) = &hits[0];
        assert_eq!(*low, 5);
        assert_eq!(*high, 5);
        assert!(values.contains(&100));
    }

    /// Range queries report exactly the intervals they overlap.
    #[test]
    fn test_insert_and_query_range() {
        let mut index = IntervalTree::new();
        for (low, high, value) in [(10, 20, 1), (15, 25, 2), (30, 40, 3)].iter().copied() {
            index.insert(low, high, value);
        }

        // Overlaps the first two intervals only.
        assert_eq!(index.query(12, 22).len(), 2);

        // Overlaps the third interval only.
        let tail_hits = index.query(35, 45);
        assert_eq!(tail_hits.len(), 1);
        assert!(tail_hits[0].2.contains(&3));
    }

    /// Removing one value leaves the other values of the same interval in place.
    #[test]
    fn test_remove_value() {
        let mut index = IntervalTree::new();
        index.insert(5, 5, 100);
        index.insert(5, 5, 200);

        let before = index.query(5, 5);
        assert_eq!(before.len(), 1);
        assert_eq!(before[0].2.len(), 2);

        index.remove(5, 5, &100);

        let after = index.query(5, 5);
        assert_eq!(after.len(), 1);
        let remaining = &after[0].2;
        assert_eq!(remaining.len(), 1);
        assert!(remaining.contains(&200));
    }

    /// Two entry calls for the same interval accumulate into one value set.
    #[test]
    fn test_entry_api() {
        let mut index: IntervalTree<i32> = IntervalTree::new();

        for value in [42, 43].iter().copied() {
            index.entry(10, 10).or_insert_with(HashSet::new).insert(value);
        }

        let hits = index.query(10, 10);
        assert_eq!(hits.len(), 1);

        let values = &hits[0].2;
        assert_eq!(values.len(), 2);
        assert!(values.contains(&42));
        assert!(values.contains(&43));
    }

    /// Sparse sheet simulation: one cell every 10,000 rows across a million rows.
    #[test]
    fn test_large_sparse_tree() {
        let mut index = IntervalTree::new();

        (0..1_000_000u32)
            .step_by(10_000)
            .for_each(|row| index.insert(row, row, row as i32));

        assert_eq!(index.len(), 100);

        // Upper half of the sheet.
        let upper_half = index.query(500_000, u32::MAX);
        assert_eq!(upper_half.len(), 50);
    }
}