formualizer_eval/engine/arena/
ast.rs

//! AST arena with structural sharing and deduplication.
//! Stores formula AST nodes efficiently with content-addressable storage.
3use super::string_interner::{StringId, StringInterner};
4use super::value_ref::ValueRef;
5use formualizer_parse::parser::TableSpecifier;
6use rustc_hash::FxHashMap;
7use std::collections::hash_map::DefaultHasher;
8use std::fmt;
9use std::hash::{Hash, Hasher};
10
/// Handle to a node stored in an `AstArena`.
///
/// The wrapped value is the node's index in the arena's node vector.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub struct AstNodeId(u32);

impl AstNodeId {
    /// Returns the raw index behind this handle.
    pub fn as_u32(self) -> u32 {
        let AstNodeId(raw) = self;
        raw
    }
}

impl fmt::Display for AstNodeId {
    /// Renders the handle as `AstNode(<index>)`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_fmt(format_args!("AstNode({})", self.as_u32()))
    }
}
26
/// Handle to a deduplicated table specifier stored in an `AstArena`.
///
/// The wrapped value is the specifier's index in the arena's specifier list.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub struct TableSpecId(u32);

impl TableSpecId {
    /// Returns the raw index behind this handle.
    pub fn as_u32(self) -> u32 {
        let TableSpecId(raw) = self;
        raw
    }
}
35
36/// Compact representation of AST nodes in the arena
37#[derive(Debug, Clone, PartialEq, Eq, Hash)]
38pub enum AstNodeData {
39    /// Literal value
40    Literal(ValueRef),
41
42    /// Cell or range reference
43    Reference {
44        original_id: StringId,    // Original reference string
45        ref_type: CompactRefType, // Compact reference representation
46    },
47
48    /// Unary operation
49    UnaryOp { op_id: StringId, expr_id: AstNodeId },
50
51    /// Binary operation
52    BinaryOp {
53        op_id: StringId,
54        left_id: AstNodeId,
55        right_id: AstNodeId,
56    },
57
58    /// Function call
59    Function {
60        name_id: StringId,
61        args_offset: u32, // Index into args array
62        args_count: u16,  // Number of arguments
63    },
64
65    /// Array literal
66    Array {
67        rows: u16,
68        cols: u16,
69        elements_offset: u32, // Index into elements array
70    },
71}
72
73/// Identifies a sheet either by stable registry id or by unresolved name.
74#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
75pub enum SheetKey {
76    Id(u16),
77    Name(StringId),
78}
79
80/// Compact representation of reference types
81#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
82pub enum CompactRefType {
83    Cell {
84        sheet: Option<SheetKey>,
85        row: u32,
86        col: u32,
87    },
88    Range {
89        sheet: Option<SheetKey>,
90        start_row: u32,
91        start_col: u32,
92        end_row: u32,
93        end_col: u32,
94    },
95    NamedRange(StringId),
96    Table {
97        name_id: StringId,
98        specifier_id: Option<TableSpecId>,
99    },
100}
101
102/// Arena for storing AST nodes with deduplication
103pub struct AstArena {
104    /// Node storage
105    nodes: Vec<AstNodeData>,
106
107    /// Hash -> node index for deduplication
108    dedup_map: FxHashMap<u64, AstNodeId>,
109
110    /// Function arguments storage (flattened)
111    function_args: Vec<AstNodeId>,
112
113    /// Array elements storage (flattened)
114    array_elements: Vec<AstNodeId>,
115
116    /// String pool for operators and function names
117    strings: StringInterner,
118
119    /// Structured table specifiers
120    table_specs: Vec<TableSpecifier>,
121    table_spec_dedup: FxHashMap<u64, TableSpecId>,
122
123    /// Statistics
124    dedup_hits: usize,
125}
126
127impl AstArena {
128    pub fn new() -> Self {
129        Self {
130            nodes: Vec::new(),
131            dedup_map: FxHashMap::default(),
132            function_args: Vec::new(),
133            array_elements: Vec::new(),
134            strings: StringInterner::new(),
135            table_specs: Vec::new(),
136            table_spec_dedup: FxHashMap::default(),
137            dedup_hits: 0,
138        }
139    }
140
141    pub fn with_capacity(node_cap: usize) -> Self {
142        Self {
143            nodes: Vec::with_capacity(node_cap),
144            dedup_map: FxHashMap::with_capacity_and_hasher(node_cap, Default::default()),
145            function_args: Vec::with_capacity(node_cap * 2), // Assume avg 2 args
146            array_elements: Vec::with_capacity(node_cap),
147            strings: StringInterner::with_capacity(node_cap / 10),
148            table_specs: Vec::new(),
149            table_spec_dedup: FxHashMap::default(),
150            dedup_hits: 0,
151        }
152    }
153
154    /// Insert a node, deduplicating if it already exists
155    pub fn insert(&mut self, node: AstNodeData) -> AstNodeId {
156        // Compute hash
157        let hash = self.hash_node(&node);
158
159        // Check for existing node
160        if let Some(&id) = self.dedup_map.get(&hash) {
161            // Verify it's actually the same (handle hash collisions)
162            if self.nodes[id.0 as usize] == node {
163                self.dedup_hits += 1;
164                return id;
165            }
166        }
167
168        // Add new node
169        let id = AstNodeId(self.nodes.len() as u32);
170        self.nodes.push(node);
171        self.dedup_map.insert(hash, id);
172        id
173    }
174
175    /// Insert a literal node
176    pub fn insert_literal(&mut self, value: ValueRef) -> AstNodeId {
177        self.insert(AstNodeData::Literal(value))
178    }
179
180    /// Insert a reference node
181    pub fn insert_reference(&mut self, original: &str, ref_type: CompactRefType) -> AstNodeId {
182        let original_id = self.strings.intern(original);
183        self.insert(AstNodeData::Reference {
184            original_id,
185            ref_type,
186        })
187    }
188
189    /// Insert a unary operation node
190    pub fn insert_unary_op(&mut self, op: &str, expr: AstNodeId) -> AstNodeId {
191        let op_id = self.strings.intern(op);
192        self.insert(AstNodeData::UnaryOp {
193            op_id,
194            expr_id: expr,
195        })
196    }
197
198    /// Insert a binary operation node
199    pub fn insert_binary_op(&mut self, op: &str, left: AstNodeId, right: AstNodeId) -> AstNodeId {
200        let op_id = self.strings.intern(op);
201        self.insert(AstNodeData::BinaryOp {
202            op_id,
203            left_id: left,
204            right_id: right,
205        })
206    }
207
208    /// Insert a function call node
209    pub fn insert_function(&mut self, name: &str, args: Vec<AstNodeId>) -> AstNodeId {
210        let name_id = self.strings.intern(name);
211        let args_offset = self.function_args.len() as u32;
212        let args_count = args.len() as u16;
213
214        self.function_args.extend(args);
215
216        self.insert(AstNodeData::Function {
217            name_id,
218            args_offset,
219            args_count,
220        })
221    }
222
223    /// Insert an array literal node
224    pub fn insert_array(&mut self, rows: u16, cols: u16, elements: Vec<AstNodeId>) -> AstNodeId {
225        assert_eq!(
226            elements.len(),
227            (rows * cols) as usize,
228            "Array dimensions don't match element count"
229        );
230
231        let elements_offset = self.array_elements.len() as u32;
232        self.array_elements.extend(elements);
233
234        self.insert(AstNodeData::Array {
235            rows,
236            cols,
237            elements_offset,
238        })
239    }
240
241    /// Get a node by ID
242    pub fn get(&self, id: AstNodeId) -> Option<&AstNodeData> {
243        self.nodes.get(id.0 as usize)
244    }
245
246    /// Get function arguments for a function node
247    pub fn get_function_args(&self, id: AstNodeId) -> Option<&[AstNodeId]> {
248        match self.get(id)? {
249            AstNodeData::Function {
250                args_offset,
251                args_count,
252                ..
253            } => {
254                let start = *args_offset as usize;
255                let end = start + *args_count as usize;
256                Some(&self.function_args[start..end])
257            }
258            _ => None,
259        }
260    }
261
262    /// Get array elements for an array node
263    pub fn get_array_elements(&self, id: AstNodeId) -> Option<&[AstNodeId]> {
264        match self.get(id)? {
265            AstNodeData::Array {
266                rows,
267                cols,
268                elements_offset,
269            } => {
270                let start = *elements_offset as usize;
271                let count = (*rows * *cols) as usize;
272                let end = start + count;
273                Some(&self.array_elements[start..end])
274            }
275            _ => None,
276        }
277    }
278
279    pub fn get_array_elements_info(&self, id: AstNodeId) -> Option<(u16, u16, &[AstNodeId])> {
280        match self.get(id)? {
281            AstNodeData::Array { rows, cols, .. } => {
282                let elements = self.get_array_elements(id)?;
283                Some((*rows, *cols, elements))
284            }
285            _ => None,
286        }
287    }
288
289    /// Resolve a string ID to its content
290    pub fn resolve_string(&self, id: StringId) -> &str {
291        self.strings.resolve(id)
292    }
293
294    /// Get the string interner (for external use)
295    pub fn strings(&self) -> &StringInterner {
296        &self.strings
297    }
298
299    /// Get mutable access to the string interner
300    pub fn strings_mut(&mut self) -> &mut StringInterner {
301        &mut self.strings
302    }
303
304    pub fn intern_table_specifier(&mut self, specifier: &TableSpecifier) -> TableSpecId {
305        let hash = {
306            let mut hasher = DefaultHasher::new();
307            specifier.hash(&mut hasher);
308            hasher.finish()
309        };
310
311        if let Some(&id) = self.table_spec_dedup.get(&hash)
312            && self
313                .table_specs
314                .get(id.0 as usize)
315                .is_some_and(|existing| existing == specifier)
316        {
317            return id;
318        }
319
320        let id = TableSpecId(self.table_specs.len() as u32);
321        self.table_specs.push(specifier.clone());
322        self.table_spec_dedup.insert(hash, id);
323        id
324    }
325
326    pub fn resolve_table_specifier(&self, id: TableSpecId) -> Option<&TableSpecifier> {
327        self.table_specs.get(id.0 as usize)
328    }
329
330    /// Compute hash for a node
331    fn hash_node(&self, node: &AstNodeData) -> u64 {
332        let mut hasher = DefaultHasher::new();
333        node.hash(&mut hasher);
334        hasher.finish()
335    }
336
337    /// Get statistics about the arena
338    pub fn stats(&self) -> AstArenaStats {
339        AstArenaStats {
340            node_count: self.nodes.len(),
341            dedup_hits: self.dedup_hits,
342            string_count: self.strings.len(),
343            table_spec_count: self.table_specs.len(),
344            total_args: self.function_args.len(),
345            total_array_elements: self.array_elements.len(),
346        }
347    }
348
349    /// Returns memory usage in bytes (approximate)
350    pub fn memory_usage(&self) -> usize {
351        self.nodes.capacity() * std::mem::size_of::<AstNodeData>()
352            + self.dedup_map.capacity() * (8 + 4) // hash + id
353            + self.function_args.capacity() * 4
354            + self.array_elements.capacity() * 4
355            + self.strings.memory_usage()
356            + self.table_specs.capacity() * std::mem::size_of::<TableSpecifier>()
357            + self.table_spec_dedup.capacity() * (8 + 4)
358    }
359
360    /// Clear all nodes from the arena
361    pub fn clear(&mut self) {
362        self.nodes.clear();
363        self.dedup_map.clear();
364        self.function_args.clear();
365        self.array_elements.clear();
366        self.strings.clear();
367        self.table_specs.clear();
368        self.table_spec_dedup.clear();
369        self.dedup_hits = 0;
370    }
371}
372
373impl Default for AstArena {
374    fn default() -> Self {
375        Self::new()
376    }
377}
378
379impl fmt::Debug for AstArena {
380    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
381        f.debug_struct("AstArena")
382            .field("nodes", &self.nodes.len())
383            .field("dedup_hits", &self.dedup_hits)
384            .field("strings", &self.strings.len())
385            .finish()
386    }
387}
388
/// Snapshot of counters describing an AST arena's contents.
#[derive(Clone, Debug)]
pub struct AstArenaStats {
    /// Number of distinct nodes stored.
    pub node_count: usize,
    /// Number of inserts that returned an already-stored node.
    pub dedup_hits: usize,
    /// Number of interned strings.
    pub string_count: usize,
    /// Number of distinct table specifiers stored.
    pub table_spec_count: usize,
    /// Length of the flattened function-argument buffer.
    pub total_args: usize,
    /// Length of the flattened array-element buffer.
    pub total_array_elements: usize,
}
399
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ast_arena_literal() {
        let mut arena = AstArena::new();

        let lit1 = arena.insert_literal(ValueRef::small_int(42).unwrap());
        let lit2 = arena.insert_literal(ValueRef::boolean(true));

        assert_ne!(lit1, lit2);

        match arena.get(lit1) {
            Some(AstNodeData::Literal(v)) => {
                assert_eq!(v.as_small_int(), Some(42));
            }
            _ => panic!("Expected literal node"),
        }
    }

    #[test]
    fn test_ast_arena_deduplication() {
        let mut arena = AstArena::new();

        // Insert same literal twice
        let lit1 = arena.insert_literal(ValueRef::small_int(42).unwrap());
        let lit2 = arena.insert_literal(ValueRef::small_int(42).unwrap());

        assert_eq!(lit1, lit2); // Should be deduplicated
        assert_eq!(arena.stats().dedup_hits, 1);
    }

    #[test]
    fn test_ast_arena_binary_op() {
        let mut arena = AstArena::new();

        let left = arena.insert_literal(ValueRef::small_int(1).unwrap());
        let right = arena.insert_literal(ValueRef::small_int(2).unwrap());
        let add = arena.insert_binary_op("+", left, right);

        match arena.get(add) {
            Some(AstNodeData::BinaryOp {
                op_id,
                left_id,
                right_id,
            }) => {
                assert_eq!(arena.resolve_string(*op_id), "+");
                assert_eq!(*left_id, left);
                assert_eq!(*right_id, right);
            }
            _ => panic!("Expected binary op node"),
        }
    }

    #[test]
    fn test_ast_arena_function() {
        let mut arena = AstArena::new();

        let arg1 = arena.insert_literal(ValueRef::small_int(10).unwrap());
        let arg2 = arena.insert_literal(ValueRef::small_int(20).unwrap());
        let arg3 = arena.insert_literal(ValueRef::small_int(30).unwrap());

        let func = arena.insert_function("SUM", vec![arg1, arg2, arg3]);

        match arena.get(func) {
            Some(AstNodeData::Function {
                name_id,
                args_count,
                ..
            }) => {
                assert_eq!(arena.resolve_string(*name_id), "SUM");
                assert_eq!(*args_count, 3);
            }
            _ => panic!("Expected function node"),
        }

        let args = arena.get_function_args(func).unwrap();
        assert_eq!(args, &[arg1, arg2, arg3]);
    }

    #[test]
    fn test_ast_arena_structural_sharing() {
        let mut arena = AstArena::new();

        // Create "A1" reference that will be shared
        let a1_ref = arena.insert_reference(
            "A1",
            CompactRefType::Cell {
                sheet: None,
                row: 1,
                col: 1,
            },
        );

        // Create "A1 + 1"
        let one = arena.insert_literal(ValueRef::small_int(1).unwrap());
        let expr1 = arena.insert_binary_op("+", a1_ref, one);

        // Create "A1 * 2"
        let two = arena.insert_literal(ValueRef::small_int(2).unwrap());
        let expr2 = arena.insert_binary_op("*", a1_ref, two);

        // A1 reference should be shared
        assert_eq!(arena.stats().node_count, 5); // A1, 1, +expr, 2, *expr
        assert_ne!(expr1, expr2);

        // Both expressions point at the very same A1 node.
        match (arena.get(expr1), arena.get(expr2)) {
            (
                Some(AstNodeData::BinaryOp { left_id: l1, .. }),
                Some(AstNodeData::BinaryOp { left_id: l2, .. }),
            ) => {
                assert_eq!(*l1, a1_ref);
                assert_eq!(*l2, a1_ref);
            }
            _ => panic!("Expected binary op nodes"),
        }

        // Try to insert A1 again - should be deduplicated
        let a1_ref2 = arena.insert_reference(
            "A1",
            CompactRefType::Cell {
                sheet: None,
                row: 1,
                col: 1,
            },
        );
        assert_eq!(a1_ref, a1_ref2);
    }

    #[test]
    fn test_ast_arena_array() {
        let mut arena = AstArena::new();

        let elements = vec![
            arena.insert_literal(ValueRef::small_int(1).unwrap()),
            arena.insert_literal(ValueRef::small_int(2).unwrap()),
            arena.insert_literal(ValueRef::small_int(3).unwrap()),
            arena.insert_literal(ValueRef::small_int(4).unwrap()),
        ];

        let array = arena.insert_array(2, 2, elements.clone());

        match arena.get(array) {
            Some(AstNodeData::Array { rows, cols, .. }) => {
                assert_eq!(*rows, 2);
                assert_eq!(*cols, 2);
            }
            _ => panic!("Expected array node"),
        }

        let stored_elements = arena.get_array_elements(array).unwrap();
        assert_eq!(stored_elements, &elements[..]);
    }

    #[test]
    fn test_ast_arena_complex_expression() {
        let mut arena = AstArena::new();

        // Build: SUM(A1:A10) + IF(B1 > 0, C1, D1)

        // A1:A10 range
        let range = arena.insert_reference(
            "A1:A10",
            CompactRefType::Range {
                sheet: None,
                start_row: 1,
                start_col: 1,
                end_row: 10,
                end_col: 1,
            },
        );

        // SUM(A1:A10)
        let sum = arena.insert_function("SUM", vec![range]);

        // B1 reference
        let b1 = arena.insert_reference(
            "B1",
            CompactRefType::Cell {
                sheet: None,
                row: 1,
                col: 2,
            },
        );

        // 0 literal
        let zero = arena.insert_literal(ValueRef::small_int(0).unwrap());

        // B1 > 0
        let condition = arena.insert_binary_op(">", b1, zero);

        // C1 and D1 references
        let c1 = arena.insert_reference(
            "C1",
            CompactRefType::Cell {
                sheet: None,
                row: 1,
                col: 3,
            },
        );
        let d1 = arena.insert_reference(
            "D1",
            CompactRefType::Cell {
                sheet: None,
                row: 1,
                col: 4,
            },
        );

        // IF(B1 > 0, C1, D1)
        let if_expr = arena.insert_function("IF", vec![condition, c1, d1]);

        // Final: SUM(...) + IF(...)
        let final_expr = arena.insert_binary_op("+", sum, if_expr);

        // Verify structure
        match arena.get(final_expr) {
            Some(AstNodeData::BinaryOp {
                op_id,
                left_id,
                right_id,
            }) => {
                assert_eq!(arena.resolve_string(*op_id), "+");
                assert_eq!(*left_id, sum);
                assert_eq!(*right_id, if_expr);
            }
            _ => panic!("Expected binary op node"),
        }
        assert_eq!(arena.get_function_args(if_expr).unwrap(), &[condition, c1, d1]);

        // Every node above is distinct, so nothing is deduplicated:
        // range, sum, b1, zero, condition(>), c1, d1, if_expr, final_expr(+).
        assert_eq!(arena.stats().node_count, 9);
        assert_eq!(arena.stats().dedup_hits, 0);
    }

    #[test]
    fn test_ast_arena_string_deduplication() {
        let mut arena = AstArena::new();

        // Use same operator multiple times
        let one = arena.insert_literal(ValueRef::small_int(1).unwrap());
        let two = arena.insert_literal(ValueRef::small_int(2).unwrap());
        let three = arena.insert_literal(ValueRef::small_int(3).unwrap());

        let add1 = arena.insert_binary_op("+", one, two);
        let add2 = arena.insert_binary_op("+", two, three);
        let add3 = arena.insert_binary_op("+", one, three);

        // Different operands -> distinct nodes, even though the operator is shared.
        assert_ne!(add1, add2);
        assert_ne!(add2, add3);
        assert_ne!(add1, add3);

        // "+" should be interned only once
        assert_eq!(arena.strings().len(), 1);
    }

    #[test]
    fn test_ast_arena_clear() {
        let mut arena = AstArena::new();

        arena.insert_literal(ValueRef::small_int(1).unwrap());
        arena.insert_literal(ValueRef::small_int(2).unwrap());
        let left = arena.insert_literal(ValueRef::small_int(3).unwrap());
        let right = arena.insert_literal(ValueRef::small_int(4).unwrap());
        arena.insert_binary_op("+", left, right);

        assert_eq!(arena.stats().node_count, 5);

        arena.clear();

        assert_eq!(arena.stats().node_count, 0);
        assert_eq!(arena.strings().len(), 0);
    }
}