Skip to main content

formualizer_eval/engine/arena/
ast.rs

//! AST arena with structural sharing and deduplication.
//!
//! Stores formula AST nodes efficiently with content-addressable storage.
3use super::string_interner::{StringId, StringInterner};
4use super::value_ref::ValueRef;
5use formualizer_parse::parser::{ExternalRefKind, TableSpecifier};
6use rustc_hash::FxHashMap;
7use std::collections::hash_map::DefaultHasher;
8use std::fmt;
9use std::hash::{Hash, Hasher};
10
/// Reference to an AST node in the arena.
///
/// Wraps the node's `u32` index. `AstArena::insert` assigns it from the node's position in
/// the arena's node vector, and `AstArena::get` uses it to look the node up again.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub struct AstNodeId(u32);
14
15impl AstNodeId {
16    pub fn as_u32(self) -> u32 {
17        self.0
18    }
19}
20
21impl fmt::Display for AstNodeId {
22    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
23        write!(f, "AstNode({})", self.0)
24    }
25}
26
/// Handle to a `TableSpecifier` stored in an `AstArena`.
///
/// Wraps the specifier's `u32` index in the arena's table-specifier list; handed out by
/// `AstArena::intern_table_specifier` and resolved by `AstArena::resolve_table_specifier`.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub struct TableSpecId(u32);
29
30impl TableSpecId {
31    pub fn as_u32(self) -> u32 {
32        self.0
33    }
34}
35
/// Compact representation of AST nodes in the arena.
///
/// Child nodes are referenced by `AstNodeId` and strings by `StringId`, so a node holds only
/// small ids and offsets. The derived `Hash` and `PartialEq` are what `AstArena::insert` uses
/// to detect duplicates, which means every field below (including the storage offsets)
/// takes part in deduplication.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum AstNodeData {
    /// Literal value
    Literal(ValueRef),

    /// Cell or range reference
    Reference {
        original_id: StringId,    // Interned original reference string
        ref_type: CompactRefType, // Compact reference representation
    },

    /// Unary operation: `op_id` is the interned operator text, `expr_id` the operand node.
    UnaryOp { op_id: StringId, expr_id: AstNodeId },

    /// Binary operation: `op_id` is the interned operator text.
    BinaryOp {
        op_id: StringId,
        left_id: AstNodeId,
        right_id: AstNodeId,
    },

    /// Function call; the arguments live in the arena's flattened argument storage.
    Function {
        name_id: StringId, // Interned function name
        args_offset: u32,  // Index of the first argument in the args array
        args_count: u16,   // Number of arguments
    },

    /// Array literal; holds `rows * cols` elements in the arena's flattened element storage.
    Array {
        rows: u16,
        cols: u16,
        elements_offset: u32, // Index of the first element in the elements array
    },
}
72
/// Identifies a sheet either by stable registry id or by unresolved name.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SheetKey {
    /// Sheet identified by its registry id.
    // NOTE(review): the registry that issues these ids is not visible in this file — confirm.
    Id(u16),
    /// Sheet identified by a name stored as a `StringId`.
    // NOTE(review): presumably interned in the owning arena's string pool — confirm.
    Name(StringId),
}
79
/// Compact representation of reference types.
///
/// Every variant is `Copy`: coordinates are plain integers, names are `StringId`s and table
/// specifiers are `TableSpecId`s. The `*_abs` flags accompany the matching coordinate.
// NOTE(review): the `*_abs` flags presumably mark `$`-anchored (absolute) coordinates, and the
// unit tests in this file encode `A1` as row 1 / col 1, which suggests 1-based indices — confirm
// both against the parser.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CompactRefType {
    /// Single-cell reference; `sheet` is `None` when no sheet is attached to the reference.
    Cell {
        sheet: Option<SheetKey>,
        row: u32,
        col: u32,
        row_abs: bool,
        col_abs: bool,
    },
    /// Rectangular range reference described by its start and end coordinates.
    Range {
        sheet: Option<SheetKey>,
        start_row: u32,
        start_col: u32,
        end_row: u32,
        end_col: u32,
        start_row_abs: bool,
        start_col_abs: bool,
        end_row_abs: bool,
        end_col_abs: bool,
    },
    /// External reference; `kind` is the parser's `ExternalRefKind`.
    // NOTE(review): `raw_id`, `book_id` and `sheet_id` look like the interned raw text, workbook
    // and sheet names respectively — confirm against the code that builds this variant.
    External {
        raw_id: StringId,
        book_id: StringId,
        sheet_id: StringId,
        kind: ExternalRefKind,
    },
    /// Named range, identified by its name as a `StringId`.
    NamedRange(StringId),
    /// Structured table reference; `specifier_id` is `None` when there is no specifier.
    Table {
        name_id: StringId,
        specifier_id: Option<TableSpecId>,
    },
    /// 3D cell reference (`Sheet1:Sheet3!A1`).
    Cell3D {
        sheet_first: StringId,
        sheet_last: StringId,
        row: u32,
        col: u32,
        row_abs: bool,
        col_abs: bool,
    },
    /// 3D range reference (`Sheet1:Sheet3!A1:B2`).
    Range3D {
        sheet_first: StringId,
        sheet_last: StringId,
        start_row: u32,
        start_col: u32,
        end_row: u32,
        end_col: u32,
        start_row_abs: bool,
        start_col_abs: bool,
        end_row_abs: bool,
        end_col_abs: bool,
    },
}
135
/// Arena for storing AST nodes with deduplication.
///
/// Nodes are appended to `nodes` and addressed by `AstNodeId`. Variable-length children
/// (function arguments, array elements) are kept in separate flattened vectors and referenced
/// from the node by offset and count.
pub struct AstArena {
    /// Node storage, indexed by `AstNodeId`
    nodes: Vec<AstNodeData>,

    /// 64-bit node hash -> id of the most recently inserted node with that hash
    dedup_map: FxHashMap<u64, AstNodeId>,

    /// Function arguments storage (flattened)
    function_args: Vec<AstNodeId>,

    /// Array elements storage (flattened)
    array_elements: Vec<AstNodeId>,

    /// String pool for operators, function names and original reference text
    strings: StringInterner,

    /// Structured table specifiers, indexed by `TableSpecId`
    table_specs: Vec<TableSpecifier>,
    /// 64-bit specifier hash -> id, used to deduplicate `table_specs`
    table_spec_dedup: FxHashMap<u64, TableSpecId>,

    /// Statistics: number of `insert` calls that returned an already-stored node
    dedup_hits: usize,
}
160
161impl AstArena {
162    pub fn new() -> Self {
163        Self {
164            nodes: Vec::new(),
165            dedup_map: FxHashMap::default(),
166            function_args: Vec::new(),
167            array_elements: Vec::new(),
168            strings: StringInterner::new(),
169            table_specs: Vec::new(),
170            table_spec_dedup: FxHashMap::default(),
171            dedup_hits: 0,
172        }
173    }
174
175    pub fn with_capacity(node_cap: usize) -> Self {
176        Self {
177            nodes: Vec::with_capacity(node_cap),
178            dedup_map: FxHashMap::with_capacity_and_hasher(node_cap, Default::default()),
179            function_args: Vec::with_capacity(node_cap * 2), // Assume avg 2 args
180            array_elements: Vec::with_capacity(node_cap),
181            strings: StringInterner::with_capacity(node_cap / 10),
182            table_specs: Vec::new(),
183            table_spec_dedup: FxHashMap::default(),
184            dedup_hits: 0,
185        }
186    }
187
188    /// Insert a node, deduplicating if it already exists
189    pub fn insert(&mut self, node: AstNodeData) -> AstNodeId {
190        // Compute hash
191        let hash = self.hash_node(&node);
192
193        // Check for existing node
194        if let Some(&id) = self.dedup_map.get(&hash) {
195            // Verify it's actually the same (handle hash collisions)
196            if self.nodes[id.0 as usize] == node {
197                self.dedup_hits += 1;
198                return id;
199            }
200        }
201
202        // Add new node
203        let id = AstNodeId(self.nodes.len() as u32);
204        self.nodes.push(node);
205        self.dedup_map.insert(hash, id);
206        id
207    }
208
209    /// Insert a literal node
210    pub fn insert_literal(&mut self, value: ValueRef) -> AstNodeId {
211        self.insert(AstNodeData::Literal(value))
212    }
213
214    /// Insert a reference node
215    pub fn insert_reference(&mut self, original: &str, ref_type: CompactRefType) -> AstNodeId {
216        let original_id = self.strings.intern(original);
217        self.insert(AstNodeData::Reference {
218            original_id,
219            ref_type,
220        })
221    }
222
223    /// Insert a unary operation node
224    pub fn insert_unary_op(&mut self, op: &str, expr: AstNodeId) -> AstNodeId {
225        let op_id = self.strings.intern(op);
226        self.insert(AstNodeData::UnaryOp {
227            op_id,
228            expr_id: expr,
229        })
230    }
231
232    /// Insert a binary operation node
233    pub fn insert_binary_op(&mut self, op: &str, left: AstNodeId, right: AstNodeId) -> AstNodeId {
234        let op_id = self.strings.intern(op);
235        self.insert(AstNodeData::BinaryOp {
236            op_id,
237            left_id: left,
238            right_id: right,
239        })
240    }
241
242    /// Insert a function call node
243    pub fn insert_function(&mut self, name: &str, args: Vec<AstNodeId>) -> AstNodeId {
244        let name_id = self.strings.intern(name);
245        let args_offset = self.function_args.len() as u32;
246        let args_count = args.len() as u16;
247
248        self.function_args.extend(args);
249
250        self.insert(AstNodeData::Function {
251            name_id,
252            args_offset,
253            args_count,
254        })
255    }
256
257    /// Insert an array literal node
258    pub fn insert_array(&mut self, rows: u16, cols: u16, elements: Vec<AstNodeId>) -> AstNodeId {
259        assert_eq!(
260            elements.len(),
261            (rows * cols) as usize,
262            "Array dimensions don't match element count"
263        );
264
265        let elements_offset = self.array_elements.len() as u32;
266        self.array_elements.extend(elements);
267
268        self.insert(AstNodeData::Array {
269            rows,
270            cols,
271            elements_offset,
272        })
273    }
274
275    /// Get a node by ID
276    pub fn get(&self, id: AstNodeId) -> Option<&AstNodeData> {
277        self.nodes.get(id.0 as usize)
278    }
279
280    /// Get function arguments for a function node
281    pub fn get_function_args(&self, id: AstNodeId) -> Option<&[AstNodeId]> {
282        match self.get(id)? {
283            AstNodeData::Function {
284                args_offset,
285                args_count,
286                ..
287            } => {
288                let start = *args_offset as usize;
289                let end = start + *args_count as usize;
290                Some(&self.function_args[start..end])
291            }
292            _ => None,
293        }
294    }
295
296    /// Get array elements for an array node
297    pub fn get_array_elements(&self, id: AstNodeId) -> Option<&[AstNodeId]> {
298        match self.get(id)? {
299            AstNodeData::Array {
300                rows,
301                cols,
302                elements_offset,
303            } => {
304                let start = *elements_offset as usize;
305                let count = (*rows * *cols) as usize;
306                let end = start + count;
307                Some(&self.array_elements[start..end])
308            }
309            _ => None,
310        }
311    }
312
313    pub fn get_array_elements_info(&self, id: AstNodeId) -> Option<(u16, u16, &[AstNodeId])> {
314        match self.get(id)? {
315            AstNodeData::Array { rows, cols, .. } => {
316                let elements = self.get_array_elements(id)?;
317                Some((*rows, *cols, elements))
318            }
319            _ => None,
320        }
321    }
322
323    /// Resolve a string ID to its content
324    pub fn resolve_string(&self, id: StringId) -> &str {
325        self.strings.resolve(id)
326    }
327
328    /// Get the string interner (for external use)
329    pub fn strings(&self) -> &StringInterner {
330        &self.strings
331    }
332
333    /// Get mutable access to the string interner
334    pub fn strings_mut(&mut self) -> &mut StringInterner {
335        &mut self.strings
336    }
337
338    pub fn intern_table_specifier(&mut self, specifier: &TableSpecifier) -> TableSpecId {
339        let hash = {
340            let mut hasher = DefaultHasher::new();
341            specifier.hash(&mut hasher);
342            hasher.finish()
343        };
344
345        if let Some(&id) = self.table_spec_dedup.get(&hash)
346            && self
347                .table_specs
348                .get(id.0 as usize)
349                .is_some_and(|existing| existing == specifier)
350        {
351            return id;
352        }
353
354        let id = TableSpecId(self.table_specs.len() as u32);
355        self.table_specs.push(specifier.clone());
356        self.table_spec_dedup.insert(hash, id);
357        id
358    }
359
360    pub fn resolve_table_specifier(&self, id: TableSpecId) -> Option<&TableSpecifier> {
361        self.table_specs.get(id.0 as usize)
362    }
363
364    /// Compute hash for a node
365    fn hash_node(&self, node: &AstNodeData) -> u64 {
366        let mut hasher = DefaultHasher::new();
367        node.hash(&mut hasher);
368        hasher.finish()
369    }
370
371    /// Get statistics about the arena
372    pub fn stats(&self) -> AstArenaStats {
373        AstArenaStats {
374            node_count: self.nodes.len(),
375            dedup_hits: self.dedup_hits,
376            string_count: self.strings.len(),
377            table_spec_count: self.table_specs.len(),
378            total_args: self.function_args.len(),
379            total_array_elements: self.array_elements.len(),
380        }
381    }
382
383    /// Returns memory usage in bytes (approximate)
384    pub fn memory_usage(&self) -> usize {
385        self.nodes.capacity() * std::mem::size_of::<AstNodeData>()
386            + self.dedup_map.capacity() * (8 + 4) // hash + id
387            + self.function_args.capacity() * 4
388            + self.array_elements.capacity() * 4
389            + self.strings.memory_usage()
390            + self.table_specs.capacity() * std::mem::size_of::<TableSpecifier>()
391            + self.table_spec_dedup.capacity() * (8 + 4)
392    }
393
394    /// Clear all nodes from the arena
395    pub fn clear(&mut self) {
396        self.nodes.clear();
397        self.dedup_map.clear();
398        self.function_args.clear();
399        self.array_elements.clear();
400        self.strings.clear();
401        self.table_specs.clear();
402        self.table_spec_dedup.clear();
403        self.dedup_hits = 0;
404    }
405}
406
407impl Default for AstArena {
408    fn default() -> Self {
409        Self::new()
410    }
411}
412
413impl fmt::Debug for AstArena {
414    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
415        f.debug_struct("AstArena")
416            .field("nodes", &self.nodes.len())
417            .field("dedup_hits", &self.dedup_hits)
418            .field("strings", &self.strings.len())
419            .finish()
420    }
421}
422
/// Statistics about the AST arena, as returned by `AstArena::stats`.
#[derive(Debug, Clone)]
pub struct AstArenaStats {
    /// Number of nodes currently stored.
    pub node_count: usize,
    /// Number of `insert` calls that returned an already-stored node.
    pub dedup_hits: usize,
    /// Number of strings in the arena's interner.
    pub string_count: usize,
    /// Number of stored table specifiers.
    pub table_spec_count: usize,
    /// Total entries in the flattened function-argument storage.
    pub total_args: usize,
    /// Total entries in the flattened array-element storage.
    pub total_array_elements: usize,
}
433
#[cfg(test)]
mod tests {
    //! Unit tests for the arena: node insertion, deduplication, side-table storage for
    //! function arguments and array elements, string interning and `clear`.

    use super::*;

    /// Distinct literals get distinct ids and can be read back.
    #[test]
    fn test_ast_arena_literal() {
        let mut arena = AstArena::new();

        let lit1 = arena.insert_literal(ValueRef::small_int(42).unwrap());
        let lit2 = arena.insert_literal(ValueRef::boolean(true));

        assert_ne!(lit1, lit2);

        match arena.get(lit1) {
            Some(AstNodeData::Literal(v)) => {
                assert_eq!(v.as_small_int(), Some(42));
            }
            _ => panic!("Expected literal node"),
        }
    }

    /// Inserting the same literal twice yields one node and one dedup hit.
    #[test]
    fn test_ast_arena_deduplication() {
        let mut arena = AstArena::new();

        // Insert same literal twice
        let lit1 = arena.insert_literal(ValueRef::small_int(42).unwrap());
        let lit2 = arena.insert_literal(ValueRef::small_int(42).unwrap());

        assert_eq!(lit1, lit2); // Should be deduplicated
        assert_eq!(arena.stats().dedup_hits, 1);
    }

    /// A binary op stores its operator text and both operand ids.
    #[test]
    fn test_ast_arena_binary_op() {
        let mut arena = AstArena::new();

        let left = arena.insert_literal(ValueRef::small_int(1).unwrap());
        let right = arena.insert_literal(ValueRef::small_int(2).unwrap());
        let add = arena.insert_binary_op("+", left, right);

        match arena.get(add) {
            Some(AstNodeData::BinaryOp {
                op_id,
                left_id,
                right_id,
            }) => {
                assert_eq!(arena.resolve_string(*op_id), "+");
                assert_eq!(*left_id, left);
                assert_eq!(*right_id, right);
            }
            _ => panic!("Expected binary op node"),
        }
    }

    /// A function node records its name and argument count, and its arguments round-trip.
    #[test]
    fn test_ast_arena_function() {
        let mut arena = AstArena::new();

        let arg1 = arena.insert_literal(ValueRef::small_int(10).unwrap());
        let arg2 = arena.insert_literal(ValueRef::small_int(20).unwrap());
        let arg3 = arena.insert_literal(ValueRef::small_int(30).unwrap());

        let func = arena.insert_function("SUM", vec![arg1, arg2, arg3]);

        match arena.get(func) {
            Some(AstNodeData::Function {
                name_id,
                args_count,
                ..
            }) => {
                assert_eq!(arena.resolve_string(*name_id), "SUM");
                assert_eq!(*args_count, 3);
            }
            _ => panic!("Expected function node"),
        }

        let args = arena.get_function_args(func).unwrap();
        assert_eq!(args, &[arg1, arg2, arg3]);
    }

    /// A reference used by two expressions is stored once and deduplicated on re-insert.
    #[test]
    fn test_ast_arena_structural_sharing() {
        let mut arena = AstArena::new();

        // Create "A1" reference that will be shared
        let a1_ref = arena.insert_reference(
            "A1",
            CompactRefType::Cell {
                sheet: None,
                row: 1,
                col: 1,
                row_abs: false,
                col_abs: false,
            },
        );

        // Create "A1 + 1"
        let one = arena.insert_literal(ValueRef::small_int(1).unwrap());
        let expr1 = arena.insert_binary_op("+", a1_ref, one);

        // Create "A1 * 2"
        let two = arena.insert_literal(ValueRef::small_int(2).unwrap());
        let expr2 = arena.insert_binary_op("*", a1_ref, two);

        // The two expressions are different nodes that share the A1 reference
        assert_ne!(expr1, expr2);
        assert_eq!(arena.stats().node_count, 5); // A1, 1, +expr, 2, *expr

        // Try to insert A1 again - should be deduplicated
        let a1_ref2 = arena.insert_reference(
            "A1",
            CompactRefType::Cell {
                sheet: None,
                row: 1,
                col: 1,
                row_abs: false,
                col_abs: false,
            },
        );
        assert_eq!(a1_ref, a1_ref2);
        assert_eq!(arena.stats().node_count, 5); // No new node was stored
        assert_eq!(arena.stats().dedup_hits, 1);
    }

    /// An array node records its dimensions and its elements round-trip.
    #[test]
    fn test_ast_arena_array() {
        let mut arena = AstArena::new();

        let elements = vec![
            arena.insert_literal(ValueRef::small_int(1).unwrap()),
            arena.insert_literal(ValueRef::small_int(2).unwrap()),
            arena.insert_literal(ValueRef::small_int(3).unwrap()),
            arena.insert_literal(ValueRef::small_int(4).unwrap()),
        ];

        let array = arena.insert_array(2, 2, elements.clone());

        match arena.get(array) {
            Some(AstNodeData::Array { rows, cols, .. }) => {
                assert_eq!(*rows, 2);
                assert_eq!(*cols, 2);
            }
            _ => panic!("Expected array node"),
        }

        let stored_elements = arena.get_array_elements(array).unwrap();
        assert_eq!(stored_elements, &elements[..]);
    }

    /// Builds `SUM(A1:A10) + IF(B1 > 0, C1, D1)` and checks the resulting node count.
    #[test]
    fn test_ast_arena_complex_expression() {
        let mut arena = AstArena::new();

        // Build: SUM(A1:A10) + IF(B1 > 0, C1, D1)

        // A1:A10 range
        let range = arena.insert_reference(
            "A1:A10",
            CompactRefType::Range {
                sheet: None,
                start_row: 1,
                start_col: 1,
                end_row: 10,
                end_col: 1,
                start_row_abs: false,
                start_col_abs: false,
                end_row_abs: false,
                end_col_abs: false,
            },
        );

        // SUM(A1:A10)
        let sum = arena.insert_function("SUM", vec![range]);

        // B1 reference
        let b1 = arena.insert_reference(
            "B1",
            CompactRefType::Cell {
                sheet: None,
                row: 1,
                col: 2,
                row_abs: false,
                col_abs: false,
            },
        );

        // 0 literal
        let zero = arena.insert_literal(ValueRef::small_int(0).unwrap());

        // B1 > 0
        let condition = arena.insert_binary_op(">", b1, zero);

        // C1 and D1 references
        let c1 = arena.insert_reference(
            "C1",
            CompactRefType::Cell {
                sheet: None,
                row: 1,
                col: 3,
                row_abs: false,
                col_abs: false,
            },
        );
        let d1 = arena.insert_reference(
            "D1",
            CompactRefType::Cell {
                sheet: None,
                row: 1,
                col: 4,
                row_abs: false,
                col_abs: false,
            },
        );

        // IF(B1 > 0, C1, D1)
        let if_expr = arena.insert_function("IF", vec![condition, c1, d1]);

        // Final: SUM(...) + IF(...)
        let final_expr = arena.insert_binary_op("+", sum, if_expr);

        // Verify structure
        assert!(arena.get(final_expr).is_some());
        // We have: range, sum, b1, zero, condition(>), c1, d1, if_expr, final_expr(+)
        // Every node is inserted exactly once, so all 9 are unique and nothing deduplicates
        assert_eq!(arena.stats().node_count, 9);
        assert_eq!(arena.stats().dedup_hits, 0);
    }

    /// Reusing an operator across several nodes interns its text only once.
    #[test]
    fn test_ast_arena_string_deduplication() {
        let mut arena = AstArena::new();

        // Use same operator multiple times
        let one = arena.insert_literal(ValueRef::small_int(1).unwrap());
        let two = arena.insert_literal(ValueRef::small_int(2).unwrap());
        let three = arena.insert_literal(ValueRef::small_int(3).unwrap());

        let add1 = arena.insert_binary_op("+", one, two);
        let add2 = arena.insert_binary_op("+", two, three);
        let add3 = arena.insert_binary_op("+", one, three);

        // Different operands produce different nodes...
        assert_ne!(add1, add2);
        assert_ne!(add2, add3);
        assert_ne!(add1, add3);

        // ...but "+" should be interned only once
        assert_eq!(arena.strings().len(), 1);
    }

    /// `clear` empties both the node storage and the string pool.
    #[test]
    fn test_ast_arena_clear() {
        let mut arena = AstArena::new();

        arena.insert_literal(ValueRef::small_int(1).unwrap());
        arena.insert_literal(ValueRef::small_int(2).unwrap());
        let left = arena.insert_literal(ValueRef::small_int(3).unwrap());
        let right = arena.insert_literal(ValueRef::small_int(4).unwrap());
        arena.insert_binary_op("+", left, right);

        assert_eq!(arena.stats().node_count, 5);

        arena.clear();

        assert_eq!(arena.stats().node_count, 0);
        assert_eq!(arena.strings().len(), 0);
    }
}