Skip to main content

formualizer_eval/engine/arena/
ast.rs

//! AST arena with structural sharing and deduplication.
//! Stores formula AST nodes compactly and deduplicates nodes that hash and compare equal.
3use super::string_interner::{StringId, StringInterner};
4use super::value_ref::ValueRef;
5use formualizer_parse::parser::{ExternalRefKind, TableSpecifier};
6use rustc_hash::FxHashMap;
7use std::collections::hash_map::DefaultHasher;
8use std::fmt;
9use std::hash::{Hash, Hasher};
10
/// Handle identifying a single AST node held by an arena.
///
/// The wrapped `u32` is the node's position in the arena's node storage.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub struct AstNodeId(u32);

impl AstNodeId {
    /// Returns the raw `u32` wrapped by this id.
    pub fn as_u32(self) -> u32 {
        let Self(raw) = self;
        raw
    }
}

impl fmt::Display for AstNodeId {
    /// Renders the id as `AstNode(<raw>)`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_fmt(format_args!("AstNode({})", self.as_u32()))
    }
}
26
/// Handle identifying a table specifier interned in an arena.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub struct TableSpecId(u32);

impl TableSpecId {
    /// Returns the raw `u32` wrapped by this id.
    pub fn as_u32(self) -> u32 {
        let Self(raw) = self;
        raw
    }
}
35
/// Compact representation of AST nodes in the arena.
///
/// Children are referenced by `AstNodeId` and text by `StringId`, so a node
/// never owns another node directly. The derived `PartialEq`/`Hash` are what
/// `AstArena::insert` uses to detect duplicates, which means every field —
/// including the storage offsets of `Function` and `Array` — takes part in
/// the comparison.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum AstNodeData {
    /// Literal value
    Literal(ValueRef),

    /// Cell or range reference
    Reference {
        original_id: StringId,    // Original reference string (interned in the arena)
        ref_type: CompactRefType, // Compact reference representation
    },

    /// Unary operation: `op_id` is the interned operator text, `expr_id` the operand.
    UnaryOp { op_id: StringId, expr_id: AstNodeId },

    /// Binary operation: `op_id` is the interned operator text.
    BinaryOp {
        op_id: StringId,
        left_id: AstNodeId,
        right_id: AstNodeId,
    },

    /// Function call
    ///
    /// The arguments live outside the node, in the arena's flattened argument
    /// storage, as the `args_count` ids starting at `args_offset`.
    Function {
        name_id: StringId, // Interned function name
        args_offset: u32,  // Index of the first argument in the args array
        args_count: u16,   // Number of arguments
    },

    /// Array literal
    ///
    /// The `rows * cols` element ids live in the arena's flattened element
    /// storage starting at `elements_offset`.
    // NOTE(review): the element order (row-major vs column-major) is decided by
    // whoever builds the `elements` vector; confirm against the caller.
    Array {
        rows: u16,
        cols: u16,
        elements_offset: u32, // Index of the first element in the elements array
    },
}
72
/// Identifies a sheet either by stable registry id or by unresolved name.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SheetKey {
    /// Sheet known by its registry id.
    Id(u16),
    /// Sheet known only by name; the name is stored as an interned string id.
    Name(StringId),
}
79
/// Compact representation of reference types.
///
/// All variants are `Copy`; text is held as interned `StringId`s rather than
/// owned strings.
// NOTE(review): the unit tests in this file use row 1 / col 1 for `A1`; confirm
// the row/column indexing convention against the code that builds these values.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CompactRefType {
    /// Single-cell reference.
    Cell {
        /// `None` when the reference carries no sheet qualifier.
        sheet: Option<SheetKey>,
        row: u32,
        col: u32,
        // Presumably the `$` (absolute) markers on row / column — TODO confirm.
        row_abs: bool,
        col_abs: bool,
    },
    /// Rectangular range described by its start and end corners.
    Range {
        /// `None` when the reference carries no sheet qualifier.
        sheet: Option<SheetKey>,
        start_row: u32,
        start_col: u32,
        end_row: u32,
        end_col: u32,
        // Presumably the `$` (absolute) markers for each corner — TODO confirm.
        start_row_abs: bool,
        start_col_abs: bool,
        end_row_abs: bool,
        end_col_abs: bool,
    },
    /// Reference into another workbook; all text parts are interned strings.
    External {
        raw_id: StringId,
        book_id: StringId,
        sheet_id: StringId,
        kind: ExternalRefKind,
    },
    /// Named range, identified by its interned name.
    NamedRange(StringId),
    /// Structured table reference.
    Table {
        name_id: StringId,
        /// Id of the interned `TableSpecifier`, or `None` when there is no specifier.
        specifier_id: Option<TableSpecId>,
    },
}
113
/// Arena for storing AST nodes with deduplication.
///
/// Nodes are appended to `nodes` and addressed by `AstNodeId`; variable-length
/// payloads (function arguments, array elements) are kept in separate flattened
/// vectors and referenced from the node by offset.
pub struct AstArena {
    /// Node storage; an `AstNodeId` is an index into this vector
    nodes: Vec<AstNodeData>,

    /// Hash -> node index for deduplication (one entry per hash value)
    dedup_map: FxHashMap<u64, AstNodeId>,

    /// Function arguments storage (flattened)
    function_args: Vec<AstNodeId>,

    /// Array elements storage (flattened)
    array_elements: Vec<AstNodeId>,

    /// String pool for operators, function names and original reference text
    strings: StringInterner,

    /// Structured table specifiers; a `TableSpecId` is an index into this vector
    table_specs: Vec<TableSpecifier>,
    /// Hash -> table specifier index for deduplication
    table_spec_dedup: FxHashMap<u64, TableSpecId>,

    /// Statistics: number of `insert` calls that returned an existing node
    dedup_hits: usize,
}
138
139impl AstArena {
140    pub fn new() -> Self {
141        Self {
142            nodes: Vec::new(),
143            dedup_map: FxHashMap::default(),
144            function_args: Vec::new(),
145            array_elements: Vec::new(),
146            strings: StringInterner::new(),
147            table_specs: Vec::new(),
148            table_spec_dedup: FxHashMap::default(),
149            dedup_hits: 0,
150        }
151    }
152
153    pub fn with_capacity(node_cap: usize) -> Self {
154        Self {
155            nodes: Vec::with_capacity(node_cap),
156            dedup_map: FxHashMap::with_capacity_and_hasher(node_cap, Default::default()),
157            function_args: Vec::with_capacity(node_cap * 2), // Assume avg 2 args
158            array_elements: Vec::with_capacity(node_cap),
159            strings: StringInterner::with_capacity(node_cap / 10),
160            table_specs: Vec::new(),
161            table_spec_dedup: FxHashMap::default(),
162            dedup_hits: 0,
163        }
164    }
165
166    /// Insert a node, deduplicating if it already exists
167    pub fn insert(&mut self, node: AstNodeData) -> AstNodeId {
168        // Compute hash
169        let hash = self.hash_node(&node);
170
171        // Check for existing node
172        if let Some(&id) = self.dedup_map.get(&hash) {
173            // Verify it's actually the same (handle hash collisions)
174            if self.nodes[id.0 as usize] == node {
175                self.dedup_hits += 1;
176                return id;
177            }
178        }
179
180        // Add new node
181        let id = AstNodeId(self.nodes.len() as u32);
182        self.nodes.push(node);
183        self.dedup_map.insert(hash, id);
184        id
185    }
186
187    /// Insert a literal node
188    pub fn insert_literal(&mut self, value: ValueRef) -> AstNodeId {
189        self.insert(AstNodeData::Literal(value))
190    }
191
192    /// Insert a reference node
193    pub fn insert_reference(&mut self, original: &str, ref_type: CompactRefType) -> AstNodeId {
194        let original_id = self.strings.intern(original);
195        self.insert(AstNodeData::Reference {
196            original_id,
197            ref_type,
198        })
199    }
200
201    /// Insert a unary operation node
202    pub fn insert_unary_op(&mut self, op: &str, expr: AstNodeId) -> AstNodeId {
203        let op_id = self.strings.intern(op);
204        self.insert(AstNodeData::UnaryOp {
205            op_id,
206            expr_id: expr,
207        })
208    }
209
210    /// Insert a binary operation node
211    pub fn insert_binary_op(&mut self, op: &str, left: AstNodeId, right: AstNodeId) -> AstNodeId {
212        let op_id = self.strings.intern(op);
213        self.insert(AstNodeData::BinaryOp {
214            op_id,
215            left_id: left,
216            right_id: right,
217        })
218    }
219
220    /// Insert a function call node
221    pub fn insert_function(&mut self, name: &str, args: Vec<AstNodeId>) -> AstNodeId {
222        let name_id = self.strings.intern(name);
223        let args_offset = self.function_args.len() as u32;
224        let args_count = args.len() as u16;
225
226        self.function_args.extend(args);
227
228        self.insert(AstNodeData::Function {
229            name_id,
230            args_offset,
231            args_count,
232        })
233    }
234
235    /// Insert an array literal node
236    pub fn insert_array(&mut self, rows: u16, cols: u16, elements: Vec<AstNodeId>) -> AstNodeId {
237        assert_eq!(
238            elements.len(),
239            (rows * cols) as usize,
240            "Array dimensions don't match element count"
241        );
242
243        let elements_offset = self.array_elements.len() as u32;
244        self.array_elements.extend(elements);
245
246        self.insert(AstNodeData::Array {
247            rows,
248            cols,
249            elements_offset,
250        })
251    }
252
253    /// Get a node by ID
254    pub fn get(&self, id: AstNodeId) -> Option<&AstNodeData> {
255        self.nodes.get(id.0 as usize)
256    }
257
258    /// Get function arguments for a function node
259    pub fn get_function_args(&self, id: AstNodeId) -> Option<&[AstNodeId]> {
260        match self.get(id)? {
261            AstNodeData::Function {
262                args_offset,
263                args_count,
264                ..
265            } => {
266                let start = *args_offset as usize;
267                let end = start + *args_count as usize;
268                Some(&self.function_args[start..end])
269            }
270            _ => None,
271        }
272    }
273
274    /// Get array elements for an array node
275    pub fn get_array_elements(&self, id: AstNodeId) -> Option<&[AstNodeId]> {
276        match self.get(id)? {
277            AstNodeData::Array {
278                rows,
279                cols,
280                elements_offset,
281            } => {
282                let start = *elements_offset as usize;
283                let count = (*rows * *cols) as usize;
284                let end = start + count;
285                Some(&self.array_elements[start..end])
286            }
287            _ => None,
288        }
289    }
290
291    pub fn get_array_elements_info(&self, id: AstNodeId) -> Option<(u16, u16, &[AstNodeId])> {
292        match self.get(id)? {
293            AstNodeData::Array { rows, cols, .. } => {
294                let elements = self.get_array_elements(id)?;
295                Some((*rows, *cols, elements))
296            }
297            _ => None,
298        }
299    }
300
301    /// Resolve a string ID to its content
302    pub fn resolve_string(&self, id: StringId) -> &str {
303        self.strings.resolve(id)
304    }
305
306    /// Get the string interner (for external use)
307    pub fn strings(&self) -> &StringInterner {
308        &self.strings
309    }
310
311    /// Get mutable access to the string interner
312    pub fn strings_mut(&mut self) -> &mut StringInterner {
313        &mut self.strings
314    }
315
316    pub fn intern_table_specifier(&mut self, specifier: &TableSpecifier) -> TableSpecId {
317        let hash = {
318            let mut hasher = DefaultHasher::new();
319            specifier.hash(&mut hasher);
320            hasher.finish()
321        };
322
323        if let Some(&id) = self.table_spec_dedup.get(&hash)
324            && self
325                .table_specs
326                .get(id.0 as usize)
327                .is_some_and(|existing| existing == specifier)
328        {
329            return id;
330        }
331
332        let id = TableSpecId(self.table_specs.len() as u32);
333        self.table_specs.push(specifier.clone());
334        self.table_spec_dedup.insert(hash, id);
335        id
336    }
337
338    pub fn resolve_table_specifier(&self, id: TableSpecId) -> Option<&TableSpecifier> {
339        self.table_specs.get(id.0 as usize)
340    }
341
342    /// Compute hash for a node
343    fn hash_node(&self, node: &AstNodeData) -> u64 {
344        let mut hasher = DefaultHasher::new();
345        node.hash(&mut hasher);
346        hasher.finish()
347    }
348
349    /// Get statistics about the arena
350    pub fn stats(&self) -> AstArenaStats {
351        AstArenaStats {
352            node_count: self.nodes.len(),
353            dedup_hits: self.dedup_hits,
354            string_count: self.strings.len(),
355            table_spec_count: self.table_specs.len(),
356            total_args: self.function_args.len(),
357            total_array_elements: self.array_elements.len(),
358        }
359    }
360
361    /// Returns memory usage in bytes (approximate)
362    pub fn memory_usage(&self) -> usize {
363        self.nodes.capacity() * std::mem::size_of::<AstNodeData>()
364            + self.dedup_map.capacity() * (8 + 4) // hash + id
365            + self.function_args.capacity() * 4
366            + self.array_elements.capacity() * 4
367            + self.strings.memory_usage()
368            + self.table_specs.capacity() * std::mem::size_of::<TableSpecifier>()
369            + self.table_spec_dedup.capacity() * (8 + 4)
370    }
371
372    /// Clear all nodes from the arena
373    pub fn clear(&mut self) {
374        self.nodes.clear();
375        self.dedup_map.clear();
376        self.function_args.clear();
377        self.array_elements.clear();
378        self.strings.clear();
379        self.table_specs.clear();
380        self.table_spec_dedup.clear();
381        self.dedup_hits = 0;
382    }
383}
384
385impl Default for AstArena {
386    fn default() -> Self {
387        Self::new()
388    }
389}
390
391impl fmt::Debug for AstArena {
392    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
393        f.debug_struct("AstArena")
394            .field("nodes", &self.nodes.len())
395            .field("dedup_hits", &self.dedup_hits)
396            .field("strings", &self.strings.len())
397            .finish()
398    }
399}
400
/// Statistics about the AST arena, as returned by `AstArena::stats`.
///
/// Every field is a snapshot taken when `stats` was called.
#[derive(Debug, Clone)]
pub struct AstArenaStats {
    /// Number of nodes currently stored.
    pub node_count: usize,
    /// Number of `insert` calls that returned an already-stored node.
    pub dedup_hits: usize,
    /// Number of strings reported by the arena's string interner.
    pub string_count: usize,
    /// Number of interned table specifiers.
    pub table_spec_count: usize,
    /// Total length of the flattened function-argument storage.
    pub total_args: usize,
    /// Total length of the flattened array-element storage.
    pub total_array_elements: usize,
}
411
#[cfg(test)]
mod tests {
    use super::*;

    /// Distinct literals get distinct ids and can be read back.
    #[test]
    fn test_ast_arena_literal() {
        let mut arena = AstArena::new();

        let lit1 = arena.insert_literal(ValueRef::small_int(42).unwrap());
        let lit2 = arena.insert_literal(ValueRef::boolean(true));

        assert_ne!(lit1, lit2);

        match arena.get(lit1) {
            Some(AstNodeData::Literal(v)) => {
                assert_eq!(v.as_small_int(), Some(42));
            }
            _ => panic!("Expected literal node"),
        }
    }

    /// Inserting an equal literal twice yields one node and one dedup hit.
    #[test]
    fn test_ast_arena_deduplication() {
        let mut arena = AstArena::new();

        // Insert same literal twice
        let lit1 = arena.insert_literal(ValueRef::small_int(42).unwrap());
        let lit2 = arena.insert_literal(ValueRef::small_int(42).unwrap());

        assert_eq!(lit1, lit2); // Should be deduplicated
        assert_eq!(arena.stats().dedup_hits, 1);
    }

    /// A binary op stores its operator text and both operand ids.
    #[test]
    fn test_ast_arena_binary_op() {
        let mut arena = AstArena::new();

        let left = arena.insert_literal(ValueRef::small_int(1).unwrap());
        let right = arena.insert_literal(ValueRef::small_int(2).unwrap());
        let add = arena.insert_binary_op("+", left, right);

        match arena.get(add) {
            Some(AstNodeData::BinaryOp {
                op_id,
                left_id,
                right_id,
            }) => {
                assert_eq!(arena.resolve_string(*op_id), "+");
                assert_eq!(*left_id, left);
                assert_eq!(*right_id, right);
            }
            _ => panic!("Expected binary op node"),
        }
    }

    /// A function node records its name and argument count, and its arguments
    /// can be read back in order.
    #[test]
    fn test_ast_arena_function() {
        let mut arena = AstArena::new();

        let arg1 = arena.insert_literal(ValueRef::small_int(10).unwrap());
        let arg2 = arena.insert_literal(ValueRef::small_int(20).unwrap());
        let arg3 = arena.insert_literal(ValueRef::small_int(30).unwrap());

        let func = arena.insert_function("SUM", vec![arg1, arg2, arg3]);

        match arena.get(func) {
            Some(AstNodeData::Function {
                name_id,
                args_count,
                ..
            }) => {
                assert_eq!(arena.resolve_string(*name_id), "SUM");
                assert_eq!(*args_count, 3);
            }
            _ => panic!("Expected function node"),
        }

        let args = arena.get_function_args(func).unwrap();
        assert_eq!(args, &[arg1, arg2, arg3]);
    }

    /// A reference used by two expressions is stored once and re-inserting it
    /// returns the same id.
    #[test]
    fn test_ast_arena_structural_sharing() {
        let mut arena = AstArena::new();

        // Create "A1" reference that will be shared
        let a1_ref = arena.insert_reference(
            "A1",
            CompactRefType::Cell {
                sheet: None,
                row: 1,
                col: 1,
                row_abs: false,
                col_abs: false,
            },
        );

        // Create "A1 + 1"
        let one = arena.insert_literal(ValueRef::small_int(1).unwrap());
        let expr1 = arena.insert_binary_op("+", a1_ref, one);

        // Create "A1 * 2"
        let two = arena.insert_literal(ValueRef::small_int(2).unwrap());
        let expr2 = arena.insert_binary_op("*", a1_ref, two);

        // The two expressions differ, but both point at the single A1 node
        assert_ne!(expr1, expr2);
        assert_eq!(arena.stats().node_count, 5); // A1, 1, +expr, 2, *expr

        // Try to insert A1 again - should be deduplicated
        let a1_ref2 = arena.insert_reference(
            "A1",
            CompactRefType::Cell {
                sheet: None,
                row: 1,
                col: 1,
                row_abs: false,
                col_abs: false,
            },
        );
        assert_eq!(a1_ref, a1_ref2);
    }

    /// An array node records its dimensions and its elements can be read back.
    #[test]
    fn test_ast_arena_array() {
        let mut arena = AstArena::new();

        let elements = vec![
            arena.insert_literal(ValueRef::small_int(1).unwrap()),
            arena.insert_literal(ValueRef::small_int(2).unwrap()),
            arena.insert_literal(ValueRef::small_int(3).unwrap()),
            arena.insert_literal(ValueRef::small_int(4).unwrap()),
        ];

        let array = arena.insert_array(2, 2, elements.clone());

        match arena.get(array) {
            Some(AstNodeData::Array { rows, cols, .. }) => {
                assert_eq!(*rows, 2);
                assert_eq!(*cols, 2);
            }
            _ => panic!("Expected array node"),
        }

        let stored_elements = arena.get_array_elements(array).unwrap();
        assert_eq!(stored_elements, &elements[..]);
    }

    /// Builds a nested expression and checks the resulting node count.
    #[test]
    fn test_ast_arena_complex_expression() {
        let mut arena = AstArena::new();

        // Build: SUM(A1:A10) + IF(B1 > 0, C1, D1)

        // A1:A10 range
        let range = arena.insert_reference(
            "A1:A10",
            CompactRefType::Range {
                sheet: None,
                start_row: 1,
                start_col: 1,
                end_row: 10,
                end_col: 1,
                start_row_abs: false,
                start_col_abs: false,
                end_row_abs: false,
                end_col_abs: false,
            },
        );

        // SUM(A1:A10)
        let sum = arena.insert_function("SUM", vec![range]);

        // B1 reference
        let b1 = arena.insert_reference(
            "B1",
            CompactRefType::Cell {
                sheet: None,
                row: 1,
                col: 2,
                row_abs: false,
                col_abs: false,
            },
        );

        // 0 literal
        let zero = arena.insert_literal(ValueRef::small_int(0).unwrap());

        // B1 > 0
        let condition = arena.insert_binary_op(">", b1, zero);

        // C1 and D1 references
        let c1 = arena.insert_reference(
            "C1",
            CompactRefType::Cell {
                sheet: None,
                row: 1,
                col: 3,
                row_abs: false,
                col_abs: false,
            },
        );
        let d1 = arena.insert_reference(
            "D1",
            CompactRefType::Cell {
                sheet: None,
                row: 1,
                col: 4,
                row_abs: false,
                col_abs: false,
            },
        );

        // IF(B1 > 0, C1, D1)
        let if_expr = arena.insert_function("IF", vec![condition, c1, d1]);

        // Final: SUM(...) + IF(...)
        let final_expr = arena.insert_binary_op("+", sum, if_expr);

        // Verify structure
        assert!(arena.get(final_expr).is_some());
        // Every node here is inserted exactly once, so nothing is deduplicated:
        // range, sum, b1, zero, condition(>), c1, d1, if_expr, final_expr(+)
        assert_eq!(arena.stats().node_count, 9);
    }

    /// Reusing one operator across several nodes interns its text only once.
    #[test]
    fn test_ast_arena_string_deduplication() {
        let mut arena = AstArena::new();

        // Use same operator multiple times
        let one = arena.insert_literal(ValueRef::small_int(1).unwrap());
        let two = arena.insert_literal(ValueRef::small_int(2).unwrap());
        let three = arena.insert_literal(ValueRef::small_int(3).unwrap());

        let add1 = arena.insert_binary_op("+", one, two);
        let add2 = arena.insert_binary_op("+", two, three);
        let add3 = arena.insert_binary_op("+", one, three);

        // Different operands => three distinct nodes sharing one operator string
        assert_ne!(add1, add2);
        assert_ne!(add2, add3);
        assert_ne!(add1, add3);

        // "+" should be interned only once
        assert_eq!(arena.strings().len(), 1);
    }

    /// `clear` drops all nodes and interned strings.
    #[test]
    fn test_ast_arena_clear() {
        let mut arena = AstArena::new();

        arena.insert_literal(ValueRef::small_int(1).unwrap());
        arena.insert_literal(ValueRef::small_int(2).unwrap());
        let left = arena.insert_literal(ValueRef::small_int(3).unwrap());
        let right = arena.insert_literal(ValueRef::small_int(4).unwrap());
        arena.insert_binary_op("+", left, right);

        assert_eq!(arena.stats().node_count, 5);

        arena.clear();

        assert_eq!(arena.stats().node_count, 0);
        assert_eq!(arena.strings().len(), 0);
    }
}