// formualizer_eval/engine/arena/ast.rs

//! AST arena with structural sharing and deduplication.
//! Stores formula AST nodes efficiently using content-addressable storage.
3use super::string_interner::{StringId, StringInterner};
4use super::value_ref::ValueRef;
5use rustc_hash::FxHashMap;
6use std::collections::hash_map::DefaultHasher;
7use std::fmt;
8use std::hash::{Hash, Hasher};
9
/// Handle to a node stored in an `AstArena`.
///
/// The wrapped integer is the node's position in the arena's node storage.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct AstNodeId(u32);

impl AstNodeId {
    /// Returns the raw index backing this handle.
    pub fn as_u32(self) -> u32 {
        let AstNodeId(raw) = self;
        raw
    }
}

impl fmt::Display for AstNodeId {
    /// Renders the handle as `AstNode(<index>)`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let raw = self.as_u32();
        f.write_fmt(format_args!("AstNode({})", raw))
    }
}
25
26/// Compact representation of AST nodes in the arena
27#[derive(Debug, Clone, PartialEq, Eq, Hash)]
28pub enum AstNodeData {
29    /// Literal value
30    Literal(ValueRef),
31
32    /// Cell or range reference
33    Reference {
34        original_id: StringId,    // Original reference string
35        ref_type: CompactRefType, // Compact reference representation
36    },
37
38    /// Unary operation
39    UnaryOp { op_id: StringId, expr_id: AstNodeId },
40
41    /// Binary operation
42    BinaryOp {
43        op_id: StringId,
44        left_id: AstNodeId,
45        right_id: AstNodeId,
46    },
47
48    /// Function call
49    Function {
50        name_id: StringId,
51        args_offset: u32, // Index into args array
52        args_count: u16,  // Number of arguments
53    },
54
55    /// Array literal
56    Array {
57        rows: u16,
58        cols: u16,
59        elements_offset: u32, // Index into elements array
60    },
61}
62
63/// Identifies a sheet either by stable registry id or by unresolved name.
64#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
65pub enum SheetKey {
66    Id(u16),
67    Name(StringId),
68}
69
70/// Compact representation of reference types
71#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
72pub enum CompactRefType {
73    Cell {
74        sheet: Option<SheetKey>,
75        row: u32,
76        col: u32,
77    },
78    Range {
79        sheet: Option<SheetKey>,
80        start_row: u32,
81        start_col: u32,
82        end_row: u32,
83        end_col: u32,
84    },
85    NamedRange(StringId),
86    Table(StringId),
87}
88
89/// Arena for storing AST nodes with deduplication
90pub struct AstArena {
91    /// Node storage
92    nodes: Vec<AstNodeData>,
93
94    /// Hash -> node index for deduplication
95    dedup_map: FxHashMap<u64, AstNodeId>,
96
97    /// Function arguments storage (flattened)
98    function_args: Vec<AstNodeId>,
99
100    /// Array elements storage (flattened)
101    array_elements: Vec<AstNodeId>,
102
103    /// String pool for operators and function names
104    strings: StringInterner,
105
106    /// Statistics
107    dedup_hits: usize,
108}
109
110impl AstArena {
111    pub fn new() -> Self {
112        Self {
113            nodes: Vec::new(),
114            dedup_map: FxHashMap::default(),
115            function_args: Vec::new(),
116            array_elements: Vec::new(),
117            strings: StringInterner::new(),
118            dedup_hits: 0,
119        }
120    }
121
122    pub fn with_capacity(node_cap: usize) -> Self {
123        Self {
124            nodes: Vec::with_capacity(node_cap),
125            dedup_map: FxHashMap::with_capacity_and_hasher(node_cap, Default::default()),
126            function_args: Vec::with_capacity(node_cap * 2), // Assume avg 2 args
127            array_elements: Vec::with_capacity(node_cap),
128            strings: StringInterner::with_capacity(node_cap / 10),
129            dedup_hits: 0,
130        }
131    }
132
133    /// Insert a node, deduplicating if it already exists
134    pub fn insert(&mut self, node: AstNodeData) -> AstNodeId {
135        // Compute hash
136        let hash = self.hash_node(&node);
137
138        // Check for existing node
139        if let Some(&id) = self.dedup_map.get(&hash) {
140            // Verify it's actually the same (handle hash collisions)
141            if self.nodes[id.0 as usize] == node {
142                self.dedup_hits += 1;
143                return id;
144            }
145        }
146
147        // Add new node
148        let id = AstNodeId(self.nodes.len() as u32);
149        self.nodes.push(node);
150        self.dedup_map.insert(hash, id);
151        id
152    }
153
154    /// Insert a literal node
155    pub fn insert_literal(&mut self, value: ValueRef) -> AstNodeId {
156        self.insert(AstNodeData::Literal(value))
157    }
158
159    /// Insert a reference node
160    pub fn insert_reference(&mut self, original: &str, ref_type: CompactRefType) -> AstNodeId {
161        let original_id = self.strings.intern(original);
162        self.insert(AstNodeData::Reference {
163            original_id,
164            ref_type,
165        })
166    }
167
168    /// Insert a unary operation node
169    pub fn insert_unary_op(&mut self, op: &str, expr: AstNodeId) -> AstNodeId {
170        let op_id = self.strings.intern(op);
171        self.insert(AstNodeData::UnaryOp {
172            op_id,
173            expr_id: expr,
174        })
175    }
176
177    /// Insert a binary operation node
178    pub fn insert_binary_op(&mut self, op: &str, left: AstNodeId, right: AstNodeId) -> AstNodeId {
179        let op_id = self.strings.intern(op);
180        self.insert(AstNodeData::BinaryOp {
181            op_id,
182            left_id: left,
183            right_id: right,
184        })
185    }
186
187    /// Insert a function call node
188    pub fn insert_function(&mut self, name: &str, args: Vec<AstNodeId>) -> AstNodeId {
189        let name_id = self.strings.intern(name);
190        let args_offset = self.function_args.len() as u32;
191        let args_count = args.len() as u16;
192
193        self.function_args.extend(args);
194
195        self.insert(AstNodeData::Function {
196            name_id,
197            args_offset,
198            args_count,
199        })
200    }
201
202    /// Insert an array literal node
203    pub fn insert_array(&mut self, rows: u16, cols: u16, elements: Vec<AstNodeId>) -> AstNodeId {
204        assert_eq!(
205            elements.len(),
206            (rows * cols) as usize,
207            "Array dimensions don't match element count"
208        );
209
210        let elements_offset = self.array_elements.len() as u32;
211        self.array_elements.extend(elements);
212
213        self.insert(AstNodeData::Array {
214            rows,
215            cols,
216            elements_offset,
217        })
218    }
219
220    /// Get a node by ID
221    pub fn get(&self, id: AstNodeId) -> Option<&AstNodeData> {
222        self.nodes.get(id.0 as usize)
223    }
224
225    /// Get function arguments for a function node
226    pub fn get_function_args(&self, id: AstNodeId) -> Option<&[AstNodeId]> {
227        match self.get(id)? {
228            AstNodeData::Function {
229                args_offset,
230                args_count,
231                ..
232            } => {
233                let start = *args_offset as usize;
234                let end = start + *args_count as usize;
235                Some(&self.function_args[start..end])
236            }
237            _ => None,
238        }
239    }
240
241    /// Get array elements for an array node
242    pub fn get_array_elements(&self, id: AstNodeId) -> Option<&[AstNodeId]> {
243        match self.get(id)? {
244            AstNodeData::Array {
245                rows,
246                cols,
247                elements_offset,
248            } => {
249                let start = *elements_offset as usize;
250                let count = (*rows * *cols) as usize;
251                let end = start + count;
252                Some(&self.array_elements[start..end])
253            }
254            _ => None,
255        }
256    }
257
258    /// Resolve a string ID to its content
259    pub fn resolve_string(&self, id: StringId) -> &str {
260        self.strings.resolve(id)
261    }
262
263    /// Get the string interner (for external use)
264    pub fn strings(&self) -> &StringInterner {
265        &self.strings
266    }
267
268    /// Get mutable access to the string interner
269    pub fn strings_mut(&mut self) -> &mut StringInterner {
270        &mut self.strings
271    }
272
273    /// Compute hash for a node
274    fn hash_node(&self, node: &AstNodeData) -> u64 {
275        let mut hasher = DefaultHasher::new();
276        node.hash(&mut hasher);
277        hasher.finish()
278    }
279
280    /// Get statistics about the arena
281    pub fn stats(&self) -> AstArenaStats {
282        AstArenaStats {
283            node_count: self.nodes.len(),
284            dedup_hits: self.dedup_hits,
285            string_count: self.strings.len(),
286            total_args: self.function_args.len(),
287            total_array_elements: self.array_elements.len(),
288        }
289    }
290
291    /// Returns memory usage in bytes (approximate)
292    pub fn memory_usage(&self) -> usize {
293        self.nodes.capacity() * std::mem::size_of::<AstNodeData>()
294            + self.dedup_map.capacity() * (8 + 4) // hash + id
295            + self.function_args.capacity() * 4
296            + self.array_elements.capacity() * 4
297            + self.strings.memory_usage()
298    }
299
300    /// Clear all nodes from the arena
301    pub fn clear(&mut self) {
302        self.nodes.clear();
303        self.dedup_map.clear();
304        self.function_args.clear();
305        self.array_elements.clear();
306        self.strings.clear();
307        self.dedup_hits = 0;
308    }
309}
310
311impl Default for AstArena {
312    fn default() -> Self {
313        Self::new()
314    }
315}
316
317impl fmt::Debug for AstArena {
318    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
319        f.debug_struct("AstArena")
320            .field("nodes", &self.nodes.len())
321            .field("dedup_hits", &self.dedup_hits)
322            .field("strings", &self.strings.len())
323            .finish()
324    }
325}
326
/// Snapshot of an AST arena's counters, as produced by `AstArena::stats`.
///
/// Implements `PartialEq`/`Eq` so snapshots can be compared directly, and
/// `Default` (all zeros) to describe an empty arena.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct AstArenaStats {
    /// Number of distinct nodes stored.
    pub node_count: usize,
    /// Number of insertions that returned an already-stored node.
    pub dedup_hits: usize,
    /// Number of strings held by the arena's interner.
    pub string_count: usize,
    /// Length of the flattened function-argument table.
    pub total_args: usize,
    /// Length of the flattened array-element table.
    pub total_array_elements: usize,
}
336
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ast_arena_literal() {
        let mut arena = AstArena::new();

        let lit1 = arena.insert_literal(ValueRef::small_int(42).unwrap());
        let lit2 = arena.insert_literal(ValueRef::boolean(true));

        assert_ne!(lit1, lit2);

        match arena.get(lit1) {
            Some(AstNodeData::Literal(v)) => {
                assert_eq!(v.as_small_int(), Some(42));
            }
            _ => panic!("Expected literal node"),
        }
    }

    #[test]
    fn test_ast_arena_deduplication() {
        let mut arena = AstArena::new();

        // Insert same literal twice
        let lit1 = arena.insert_literal(ValueRef::small_int(42).unwrap());
        let lit2 = arena.insert_literal(ValueRef::small_int(42).unwrap());

        assert_eq!(lit1, lit2); // Should be deduplicated
        assert_eq!(arena.stats().dedup_hits, 1);
    }

    #[test]
    fn test_ast_arena_binary_op() {
        let mut arena = AstArena::new();

        let left = arena.insert_literal(ValueRef::small_int(1).unwrap());
        let right = arena.insert_literal(ValueRef::small_int(2).unwrap());
        let add = arena.insert_binary_op("+", left, right);

        match arena.get(add) {
            Some(AstNodeData::BinaryOp {
                op_id,
                left_id,
                right_id,
            }) => {
                assert_eq!(arena.resolve_string(*op_id), "+");
                assert_eq!(*left_id, left);
                assert_eq!(*right_id, right);
            }
            _ => panic!("Expected binary op node"),
        }
    }

    #[test]
    fn test_ast_arena_function() {
        let mut arena = AstArena::new();

        let arg1 = arena.insert_literal(ValueRef::small_int(10).unwrap());
        let arg2 = arena.insert_literal(ValueRef::small_int(20).unwrap());
        let arg3 = arena.insert_literal(ValueRef::small_int(30).unwrap());

        let func = arena.insert_function("SUM", vec![arg1, arg2, arg3]);

        match arena.get(func) {
            Some(AstNodeData::Function {
                name_id,
                args_count,
                ..
            }) => {
                assert_eq!(arena.resolve_string(*name_id), "SUM");
                assert_eq!(*args_count, 3);
            }
            _ => panic!("Expected function node"),
        }

        let args = arena.get_function_args(func).unwrap();
        assert_eq!(args, &[arg1, arg2, arg3]);
    }

    #[test]
    fn test_ast_arena_structural_sharing() {
        let mut arena = AstArena::new();

        // Create "A1" reference that will be shared
        let a1_ref = arena.insert_reference(
            "A1",
            CompactRefType::Cell {
                sheet: None,
                row: 1,
                col: 1,
            },
        );

        // Create "A1 + 1"
        let one = arena.insert_literal(ValueRef::small_int(1).unwrap());
        let expr1 = arena.insert_binary_op("+", a1_ref, one);

        // Create "A1 * 2"
        let two = arena.insert_literal(ValueRef::small_int(2).unwrap());
        let expr2 = arena.insert_binary_op("*", a1_ref, two);

        // Both expressions reuse the single A1 node but are distinct nodes.
        assert_ne!(expr1, expr2);
        assert_eq!(arena.stats().node_count, 5); // A1, 1, +expr, 2, *expr

        // Try to insert A1 again - should be deduplicated
        let a1_ref2 = arena.insert_reference(
            "A1",
            CompactRefType::Cell {
                sheet: None,
                row: 1,
                col: 1,
            },
        );
        assert_eq!(a1_ref, a1_ref2);
        assert_eq!(arena.stats().dedup_hits, 1);
    }

    #[test]
    fn test_ast_arena_array() {
        let mut arena = AstArena::new();

        let elements = vec![
            arena.insert_literal(ValueRef::small_int(1).unwrap()),
            arena.insert_literal(ValueRef::small_int(2).unwrap()),
            arena.insert_literal(ValueRef::small_int(3).unwrap()),
            arena.insert_literal(ValueRef::small_int(4).unwrap()),
        ];

        let array = arena.insert_array(2, 2, elements.clone());

        match arena.get(array) {
            Some(AstNodeData::Array { rows, cols, .. }) => {
                assert_eq!(*rows, 2);
                assert_eq!(*cols, 2);
            }
            _ => panic!("Expected array node"),
        }

        let stored_elements = arena.get_array_elements(array).unwrap();
        assert_eq!(stored_elements, &elements[..]);
    }

    #[test]
    fn test_ast_arena_empty_array() {
        let mut arena = AstArena::new();

        let empty = arena.insert_array(0, 0, Vec::new());

        assert!(arena.get_array_elements(empty).unwrap().is_empty());
    }

    #[test]
    fn test_ast_arena_accessors_reject_wrong_node_kind() {
        let mut arena = AstArena::new();

        let lit = arena.insert_literal(ValueRef::small_int(7).unwrap());

        assert!(arena.get_function_args(lit).is_none());
        assert!(arena.get_array_elements(lit).is_none());
        // An id past the end of the arena resolves to nothing.
        assert!(arena.get(AstNodeId(lit.as_u32() + 1)).is_none());
    }

    #[test]
    fn test_ast_arena_complex_expression() {
        let mut arena = AstArena::new();

        // Build: SUM(A1:A10) + IF(B1 > 0, C1, D1)

        // A1:A10 range
        let range = arena.insert_reference(
            "A1:A10",
            CompactRefType::Range {
                sheet: None,
                start_row: 1,
                start_col: 1,
                end_row: 10,
                end_col: 1,
            },
        );

        // SUM(A1:A10)
        let sum = arena.insert_function("SUM", vec![range]);

        // B1 reference
        let b1 = arena.insert_reference(
            "B1",
            CompactRefType::Cell {
                sheet: None,
                row: 1,
                col: 2,
            },
        );

        // 0 literal
        let zero = arena.insert_literal(ValueRef::small_int(0).unwrap());

        // B1 > 0
        let condition = arena.insert_binary_op(">", b1, zero);

        // C1 and D1 references
        let c1 = arena.insert_reference(
            "C1",
            CompactRefType::Cell {
                sheet: None,
                row: 1,
                col: 3,
            },
        );
        let d1 = arena.insert_reference(
            "D1",
            CompactRefType::Cell {
                sheet: None,
                row: 1,
                col: 4,
            },
        );

        // IF(B1 > 0, C1, D1)
        let if_expr = arena.insert_function("IF", vec![condition, c1, d1]);

        // Final: SUM(...) + IF(...)
        let final_expr = arena.insert_binary_op("+", sum, if_expr);

        // Verify structure
        assert!(arena.get(final_expr).is_some());
        // Every node in this expression is distinct:
        // range, sum, b1, zero, condition(>), c1, d1, if_expr, final_expr(+)
        assert_eq!(arena.stats().node_count, 9);
        assert_eq!(arena.stats().dedup_hits, 0);
    }

    #[test]
    fn test_ast_arena_string_deduplication() {
        let mut arena = AstArena::new();

        // Use same operator multiple times
        let one = arena.insert_literal(ValueRef::small_int(1).unwrap());
        let two = arena.insert_literal(ValueRef::small_int(2).unwrap());
        let three = arena.insert_literal(ValueRef::small_int(3).unwrap());

        let add1 = arena.insert_binary_op("+", one, two);
        let add2 = arena.insert_binary_op("+", two, three);
        let add3 = arena.insert_binary_op("+", one, three);

        // Different operands give different nodes...
        assert_ne!(add1, add2);
        assert_ne!(add2, add3);
        assert_ne!(add1, add3);
        // ...but "+" is interned only once.
        assert_eq!(arena.strings().len(), 1);
    }

    #[test]
    fn test_ast_arena_clear() {
        let mut arena = AstArena::new();

        arena.insert_literal(ValueRef::small_int(1).unwrap());
        arena.insert_literal(ValueRef::small_int(2).unwrap());
        let left = arena.insert_literal(ValueRef::small_int(3).unwrap());
        let right = arena.insert_literal(ValueRef::small_int(4).unwrap());
        arena.insert_binary_op("+", left, right);

        assert_eq!(arena.stats().node_count, 5);

        arena.clear();

        assert_eq!(arena.stats().node_count, 0);
        assert_eq!(arena.strings().len(), 0);
    }
}