1use super::string_interner::{StringId, StringInterner};
4use super::value_ref::ValueRef;
5use formualizer_parse::parser::{ExternalRefKind, TableSpecifier};
6use rustc_hash::FxHashMap;
7use std::collections::hash_map::DefaultHasher;
8use std::fmt;
9use std::hash::{Hash, Hasher};
10
/// Opaque handle to a node stored in an `AstArena`.
///
/// The wrapped value is the node's position in the arena's node table.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub struct AstNodeId(u32);

impl AstNodeId {
    /// Returns the raw index backing this handle.
    pub fn as_u32(self) -> u32 {
        let Self(raw) = self;
        raw
    }

    /// Wraps a raw index. No check is made that a node with this index exists.
    pub(crate) const fn from_u32(raw: u32) -> Self {
        AstNodeId(raw)
    }
}

impl fmt::Display for AstNodeId {
    /// Renders the handle as `AstNode(<index>)`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let raw = self.as_u32();
        write!(f, "AstNode({})", raw)
    }
}
30
/// Handle to a table specifier interned in an `AstArena`.
///
/// The wrapped value is the specifier's position in the arena's specifier table.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub struct TableSpecId(u32);

impl TableSpecId {
    /// Returns the raw index of the interned specifier.
    pub fn as_u32(self) -> u32 {
        let TableSpecId(index) = self;
        index
    }
}
39
40#[derive(Debug, Clone, PartialEq, Eq, Hash)]
42pub enum AstNodeData {
43 Literal(ValueRef),
45
46 Reference {
48 original_id: StringId, ref_type: CompactRefType, },
51
52 UnaryOp { op_id: StringId, expr_id: AstNodeId },
54
55 BinaryOp {
57 op_id: StringId,
58 left_id: AstNodeId,
59 right_id: AstNodeId,
60 },
61
62 Function {
64 name_id: StringId,
65 args_offset: u32, args_count: u16, },
68
69 Array {
71 rows: u16,
72 cols: u16,
73 elements_offset: u32, },
75}
76
77#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
79pub enum SheetKey {
80 Id(u16),
81 Name(StringId),
82}
83
84#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
86pub enum CompactRefType {
87 Cell {
88 sheet: Option<SheetKey>,
89 row: u32,
90 col: u32,
91 row_abs: bool,
92 col_abs: bool,
93 },
94 Range {
95 sheet: Option<SheetKey>,
96 start_row: u32,
97 start_col: u32,
98 end_row: u32,
99 end_col: u32,
100 start_row_abs: bool,
101 start_col_abs: bool,
102 end_row_abs: bool,
103 end_col_abs: bool,
104 },
105 External {
106 raw_id: StringId,
107 book_id: StringId,
108 sheet_id: StringId,
109 kind: ExternalRefKind,
110 },
111 NamedRange(StringId),
112 Table {
113 name_id: StringId,
114 specifier_id: Option<TableSpecId>,
115 },
116 Cell3D {
118 sheet_first: StringId,
119 sheet_last: StringId,
120 row: u32,
121 col: u32,
122 row_abs: bool,
123 col_abs: bool,
124 },
125 Range3D {
127 sheet_first: StringId,
128 sheet_last: StringId,
129 start_row: u32,
130 start_col: u32,
131 end_row: u32,
132 end_col: u32,
133 start_row_abs: bool,
134 start_col_abs: bool,
135 end_row_abs: bool,
136 end_col_abs: bool,
137 },
138}
139
140#[derive(Debug, Clone, PartialEq, Eq, Hash)]
146pub(crate) struct AstNodeEntry {
147 pub(crate) data: AstNodeData,
148 pub(crate) meta: AstNodeMetadata,
149}
150
/// Per-node metadata stored alongside each arena entry.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub(crate) struct AstNodeMetadata {
    /// Hash supplied by the caller at insertion time (0 by default).
    pub(crate) canonical_hash: u64,
    /// Classification bits for the node.
    pub(crate) labels: CanonicalLabels,
}

/// Two independent bitsets describing a node: `flags` for observed properties and
/// `rejects` for reasons the node is disqualified. Bits are the `FLAG_*` and `REJECT_*`
/// constants below, respectively.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub(crate) struct CanonicalLabels {
    pub(crate) flags: u64,
    pub(crate) rejects: u64,
}

// Not every constant/helper is referenced yet; keep the full vocabulary without lint noise.
#[allow(dead_code)]
impl CanonicalLabels {
    // --- `flags` bits -------------------------------------------------------------
    pub(crate) const FLAG_RELATIVE_ONLY: u64 = 1 << 0;
    pub(crate) const FLAG_ABSOLUTE_ONLY: u64 = 1 << 1;
    pub(crate) const FLAG_MIXED_ANCHORS: u64 = 1 << 2;
    pub(crate) const FLAG_VOLATILE: u64 = 1 << 3;
    pub(crate) const FLAG_DYNAMIC: u64 = 1 << 4;
    pub(crate) const FLAG_CONTAINS_STRUCTURED_REF: u64 = 1 << 5;
    pub(crate) const FLAG_NEEDS_PLACEMENT_REWRITE: u64 = 1 << 6;
    pub(crate) const FLAG_CONTAINS_NAME: u64 = 1 << 7;
    pub(crate) const FLAG_CONTAINS_TABLE: u64 = 1 << 8;
    pub(crate) const FLAG_CONTAINS_RANGE: u64 = 1 << 9;
    pub(crate) const FLAG_CONTAINS_ARRAY: u64 = 1 << 10;
    pub(crate) const FLAG_CONTAINS_LET_LAMBDA: u64 = 1 << 11;
    pub(crate) const FLAG_CONTAINS_FUNCTION: u64 = 1 << 12;
    pub(crate) const FLAG_EXPLICIT_SHEET: u64 = 1 << 13;
    pub(crate) const FLAG_CURRENT_SHEET: u64 = 1 << 14;

    // --- `rejects` bits -----------------------------------------------------------
    pub(crate) const REJECT_INVALID_PLACEMENT_ANCHOR: u64 = 1 << 0;
    pub(crate) const REJECT_DYNAMIC_REFERENCE: u64 = 1 << 1;
    pub(crate) const REJECT_UNKNOWN_OR_CUSTOM_FUNCTION: u64 = 1 << 2;
    pub(crate) const REJECT_LOCAL_ENVIRONMENT: u64 = 1 << 3;
    pub(crate) const REJECT_PARSER_VOLATILE_FLAG: u64 = 1 << 4;
    pub(crate) const REJECT_VOLATILE_FUNCTION: u64 = 1 << 5;
    pub(crate) const REJECT_REFERENCE_RETURNING_FUNCTION: u64 = 1 << 6;
    pub(crate) const REJECT_ARRAY_OR_SPILL_FUNCTION: u64 = 1 << 7;
    pub(crate) const REJECT_ARRAY_LITERAL: u64 = 1 << 8;
    pub(crate) const REJECT_SPILL_REFERENCE: u64 = 1 << 9;
    pub(crate) const REJECT_SPILL_RESULT_REGION_OPERATOR: u64 = 1 << 10;
    pub(crate) const REJECT_IMPLICIT_INTERSECTION_OPERATOR: u64 = 1 << 11;
    pub(crate) const REJECT_CALL_EXPRESSION: u64 = 1 << 12;
    pub(crate) const REJECT_NAMED_REFERENCE: u64 = 1 << 13;
    pub(crate) const REJECT_STRUCTURED_REFERENCE: u64 = 1 << 14;
    pub(crate) const REJECT_STRUCTURED_REFERENCE_CURRENT_ROW: u64 = 1 << 15;
    pub(crate) const REJECT_THREE_D_REFERENCE: u64 = 1 << 16;
    pub(crate) const REJECT_EXTERNAL_REFERENCE: u64 = 1 << 17;
    pub(crate) const REJECT_OPEN_RANGE_REFERENCE: u64 = 1 << 18;
    pub(crate) const REJECT_WHOLE_AXIS_REFERENCE: u64 = 1 << 19;
    pub(crate) const REJECT_UNSUPPORTED_REFERENCE: u64 = 1 << 20;

    /// Returns `true` if any bit of `flag` is set in `flags`.
    pub(crate) fn has_flag(self, flag: u64) -> bool {
        let overlap = self.flags & flag;
        overlap > 0
    }

    /// Returns `true` if any bit of `reject` is set in `rejects`.
    pub(crate) fn has_reject(self, reject: u64) -> bool {
        let overlap = self.rejects & reject;
        overlap > 0
    }
}
215
216pub struct AstArena {
218 nodes: Vec<AstNodeEntry>,
220
221 dedup_map: FxHashMap<u64, AstNodeId>,
223
224 function_args: Vec<AstNodeId>,
226
227 array_elements: Vec<AstNodeId>,
229
230 strings: StringInterner,
232
233 table_specs: Vec<TableSpecifier>,
235 table_spec_dedup: FxHashMap<u64, TableSpecId>,
236
237 dedup_hits: usize,
239}
240
241impl AstArena {
242 pub fn new() -> Self {
243 Self {
244 nodes: Vec::new(),
245 dedup_map: FxHashMap::default(),
246 function_args: Vec::new(),
247 array_elements: Vec::new(),
248 strings: StringInterner::new(),
249 table_specs: Vec::new(),
250 table_spec_dedup: FxHashMap::default(),
251 dedup_hits: 0,
252 }
253 }
254
255 pub fn with_capacity(node_cap: usize) -> Self {
256 Self {
257 nodes: Vec::with_capacity(node_cap),
258 dedup_map: FxHashMap::with_capacity_and_hasher(node_cap, Default::default()),
259 function_args: Vec::with_capacity(node_cap * 2), array_elements: Vec::with_capacity(node_cap),
261 strings: StringInterner::with_capacity(node_cap / 10),
262 table_specs: Vec::new(),
263 table_spec_dedup: FxHashMap::default(),
264 dedup_hits: 0,
265 }
266 }
267
268 pub fn insert(&mut self, node: AstNodeData) -> AstNodeId {
273 self.insert_entry(node, AstNodeMetadata::default())
274 }
275
276 #[allow(dead_code)]
283 pub(crate) fn insert_with_meta(
284 &mut self,
285 node: AstNodeData,
286 meta: AstNodeMetadata,
287 ) -> AstNodeId {
288 self.insert_entry(node, meta)
289 }
290
291 fn insert_entry(&mut self, node: AstNodeData, meta: AstNodeMetadata) -> AstNodeId {
292 let hash = self.hash_node(&node);
294
295 if let Some(&id) = self.dedup_map.get(&hash) {
297 if self.nodes[id.0 as usize].data == node {
299 self.dedup_hits += 1;
300 return id;
301 }
302 }
303
304 let id = AstNodeId(self.nodes.len() as u32);
306 self.nodes.push(AstNodeEntry { data: node, meta });
307 self.dedup_map.insert(hash, id);
308 id
309 }
310
311 pub fn insert_literal(&mut self, value: ValueRef) -> AstNodeId {
313 self.insert(AstNodeData::Literal(value))
314 }
315
316 pub fn insert_reference(&mut self, original: &str, ref_type: CompactRefType) -> AstNodeId {
318 let original_id = self.strings.intern(original);
319 self.insert(AstNodeData::Reference {
320 original_id,
321 ref_type,
322 })
323 }
324
325 pub fn insert_unary_op(&mut self, op: &str, expr: AstNodeId) -> AstNodeId {
327 let op_id = self.strings.intern(op);
328 self.insert(AstNodeData::UnaryOp {
329 op_id,
330 expr_id: expr,
331 })
332 }
333
334 pub fn insert_binary_op(&mut self, op: &str, left: AstNodeId, right: AstNodeId) -> AstNodeId {
336 let op_id = self.strings.intern(op);
337 self.insert(AstNodeData::BinaryOp {
338 op_id,
339 left_id: left,
340 right_id: right,
341 })
342 }
343
344 pub fn insert_function(&mut self, name: &str, args: Vec<AstNodeId>) -> AstNodeId {
346 let name_id = self.strings.intern(name);
347 let args_offset = self.function_args.len() as u32;
348 let args_count = args.len() as u16;
349
350 self.function_args.extend(args);
351
352 self.insert(AstNodeData::Function {
353 name_id,
354 args_offset,
355 args_count,
356 })
357 }
358
359 pub fn insert_array(&mut self, rows: u16, cols: u16, elements: Vec<AstNodeId>) -> AstNodeId {
361 assert_eq!(
362 elements.len(),
363 (rows * cols) as usize,
364 "Array dimensions don't match element count"
365 );
366
367 let elements_offset = self.array_elements.len() as u32;
368 self.array_elements.extend(elements);
369
370 self.insert(AstNodeData::Array {
371 rows,
372 cols,
373 elements_offset,
374 })
375 }
376
377 pub fn get(&self, id: AstNodeId) -> Option<&AstNodeData> {
379 self.nodes.get(id.0 as usize).map(|entry| &entry.data)
380 }
381
382 #[allow(dead_code)]
384 pub(crate) fn entry(&self, id: AstNodeId) -> Option<&AstNodeEntry> {
385 self.nodes.get(id.0 as usize)
386 }
387
388 #[allow(dead_code)]
390 pub(crate) fn metadata(&self, id: AstNodeId) -> Option<AstNodeMetadata> {
391 self.entry(id).map(|entry| entry.meta)
392 }
393
394 pub fn get_function_args(&self, id: AstNodeId) -> Option<&[AstNodeId]> {
396 match self.get(id)? {
397 AstNodeData::Function {
398 args_offset,
399 args_count,
400 ..
401 } => {
402 let start = *args_offset as usize;
403 let end = start + *args_count as usize;
404 Some(&self.function_args[start..end])
405 }
406 _ => None,
407 }
408 }
409
410 pub fn get_array_elements(&self, id: AstNodeId) -> Option<&[AstNodeId]> {
412 match self.get(id)? {
413 AstNodeData::Array {
414 rows,
415 cols,
416 elements_offset,
417 } => {
418 let start = *elements_offset as usize;
419 let count = (*rows * *cols) as usize;
420 let end = start + count;
421 Some(&self.array_elements[start..end])
422 }
423 _ => None,
424 }
425 }
426
427 pub fn get_array_elements_info(&self, id: AstNodeId) -> Option<(u16, u16, &[AstNodeId])> {
428 match self.get(id)? {
429 AstNodeData::Array { rows, cols, .. } => {
430 let elements = self.get_array_elements(id)?;
431 Some((*rows, *cols, elements))
432 }
433 _ => None,
434 }
435 }
436
437 pub fn resolve_string(&self, id: StringId) -> &str {
439 self.strings.resolve(id)
440 }
441
442 pub fn strings(&self) -> &StringInterner {
444 &self.strings
445 }
446
447 pub fn strings_mut(&mut self) -> &mut StringInterner {
449 &mut self.strings
450 }
451
452 pub fn intern_table_specifier(&mut self, specifier: &TableSpecifier) -> TableSpecId {
453 let hash = {
454 let mut hasher = DefaultHasher::new();
455 specifier.hash(&mut hasher);
456 hasher.finish()
457 };
458
459 if let Some(&id) = self.table_spec_dedup.get(&hash)
460 && self
461 .table_specs
462 .get(id.0 as usize)
463 .is_some_and(|existing| existing == specifier)
464 {
465 return id;
466 }
467
468 let id = TableSpecId(self.table_specs.len() as u32);
469 self.table_specs.push(specifier.clone());
470 self.table_spec_dedup.insert(hash, id);
471 id
472 }
473
474 pub fn resolve_table_specifier(&self, id: TableSpecId) -> Option<&TableSpecifier> {
475 self.table_specs.get(id.0 as usize)
476 }
477
478 fn hash_node(&self, node: &AstNodeData) -> u64 {
480 let mut hasher = DefaultHasher::new();
481 node.hash(&mut hasher);
482 hasher.finish()
483 }
484
485 pub fn stats(&self) -> AstArenaStats {
487 AstArenaStats {
488 node_count: self.nodes.len(),
489 dedup_hits: self.dedup_hits,
490 string_count: self.strings.len(),
491 table_spec_count: self.table_specs.len(),
492 total_args: self.function_args.len(),
493 total_array_elements: self.array_elements.len(),
494 }
495 }
496
497 pub fn memory_usage(&self) -> usize {
499 self.nodes.capacity() * std::mem::size_of::<AstNodeEntry>()
500 + self.dedup_map.capacity() * (8 + 4) + self.function_args.capacity() * 4
502 + self.array_elements.capacity() * 4
503 + self.strings.memory_usage()
504 + self.table_specs.capacity() * std::mem::size_of::<TableSpecifier>()
505 + self.table_spec_dedup.capacity() * (8 + 4)
506 }
507
508 pub fn clear(&mut self) {
510 self.nodes.clear();
511 self.dedup_map.clear();
512 self.function_args.clear();
513 self.array_elements.clear();
514 self.strings.clear();
515 self.table_specs.clear();
516 self.table_spec_dedup.clear();
517 self.dedup_hits = 0;
518 }
519}
520
521impl Default for AstArena {
522 fn default() -> Self {
523 Self::new()
524 }
525}
526
527impl fmt::Debug for AstArena {
528 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
529 f.debug_struct("AstArena")
530 .field("nodes", &self.nodes.len())
531 .field("dedup_hits", &self.dedup_hits)
532 .field("strings", &self.strings.len())
533 .finish()
534 }
535}
536
/// Point-in-time counters describing an `AstArena`.
#[derive(Debug, Clone)]
pub struct AstArenaStats {
    /// Number of stored nodes.
    pub node_count: usize,
    /// Number of insertions that returned an already-stored node.
    pub dedup_hits: usize,
    /// Number of interned strings.
    pub string_count: usize,
    /// Number of interned table specifiers.
    pub table_spec_count: usize,
    /// Length of the function-argument side table.
    pub total_args: usize,
    /// Length of the array-element side table.
    pub total_array_elements: usize,
}
547
#[cfg(test)]
mod tests {
    use super::*;

    /// Relative, sheet-less single-cell reference used by several tests.
    fn relative_cell(row: u32, col: u32) -> CompactRefType {
        CompactRefType::Cell {
            sheet: None,
            row,
            col,
            row_abs: false,
            col_abs: false,
        }
    }

    #[test]
    fn test_ast_arena_literal() {
        let mut arena = AstArena::new();

        let lit1 = arena.insert_literal(ValueRef::small_int(42).unwrap());
        let lit2 = arena.insert_literal(ValueRef::boolean(true));

        assert_ne!(lit1, lit2);

        match arena.get(lit1) {
            Some(AstNodeData::Literal(v)) => {
                assert_eq!(v.as_small_int(), Some(42));
            }
            _ => panic!("Expected literal node"),
        }
    }

    #[test]
    fn test_ast_arena_deduplication() {
        let mut arena = AstArena::new();

        let lit1 = arena.insert_literal(ValueRef::small_int(42).unwrap());
        // Inserting the same literal again must return the existing node.
        let lit2 = arena.insert_literal(ValueRef::small_int(42).unwrap());

        assert_eq!(lit1, lit2);
        assert_eq!(arena.stats().dedup_hits, 1);
    }

    #[test]
    fn test_ast_arena_binary_op() {
        let mut arena = AstArena::new();

        let left = arena.insert_literal(ValueRef::small_int(1).unwrap());
        let right = arena.insert_literal(ValueRef::small_int(2).unwrap());
        let add = arena.insert_binary_op("+", left, right);

        match arena.get(add) {
            Some(AstNodeData::BinaryOp {
                op_id,
                left_id,
                right_id,
            }) => {
                assert_eq!(arena.resolve_string(*op_id), "+");
                assert_eq!(*left_id, left);
                assert_eq!(*right_id, right);
            }
            _ => panic!("Expected binary op node"),
        }
    }

    #[test]
    fn test_ast_arena_function() {
        let mut arena = AstArena::new();

        let arg1 = arena.insert_literal(ValueRef::small_int(10).unwrap());
        let arg2 = arena.insert_literal(ValueRef::small_int(20).unwrap());
        let arg3 = arena.insert_literal(ValueRef::small_int(30).unwrap());

        let func = arena.insert_function("SUM", vec![arg1, arg2, arg3]);

        match arena.get(func) {
            Some(AstNodeData::Function {
                name_id,
                args_count,
                ..
            }) => {
                assert_eq!(arena.resolve_string(*name_id), "SUM");
                assert_eq!(*args_count, 3);
            }
            _ => panic!("Expected function node"),
        }

        let args = arena.get_function_args(func).unwrap();
        assert_eq!(args, &[arg1, arg2, arg3]);
    }

    #[test]
    fn test_ast_arena_structural_sharing() {
        let mut arena = AstArena::new();

        let a1_ref = arena.insert_reference("A1", relative_cell(1, 1));

        let one = arena.insert_literal(ValueRef::small_int(1).unwrap());
        let expr1 = arena.insert_binary_op("+", a1_ref, one);

        let two = arena.insert_literal(ValueRef::small_int(2).unwrap());
        let expr2 = arena.insert_binary_op("*", a1_ref, two);

        // Both expressions share the single A1 node: A1, 1, A1+1, 2, A1*2.
        assert_eq!(arena.stats().node_count, 5);
        assert_ne!(expr1, expr2);

        let a1_ref2 = arena.insert_reference("A1", relative_cell(1, 1));
        assert_eq!(a1_ref, a1_ref2);
    }

    #[test]
    fn test_ast_arena_array() {
        let mut arena = AstArena::new();

        let elements = vec![
            arena.insert_literal(ValueRef::small_int(1).unwrap()),
            arena.insert_literal(ValueRef::small_int(2).unwrap()),
            arena.insert_literal(ValueRef::small_int(3).unwrap()),
            arena.insert_literal(ValueRef::small_int(4).unwrap()),
        ];

        let array = arena.insert_array(2, 2, elements.clone());

        match arena.get(array) {
            Some(AstNodeData::Array { rows, cols, .. }) => {
                assert_eq!(*rows, 2);
                assert_eq!(*cols, 2);
            }
            _ => panic!("Expected array node"),
        }

        let stored_elements = arena.get_array_elements(array).unwrap();
        assert_eq!(stored_elements, &elements[..]);
    }

    #[test]
    fn test_ast_arena_complex_expression() {
        let mut arena = AstArena::new();

        // SUM(A1:A10) + IF(B1 > 0, C1, D1)
        let range = arena.insert_reference(
            "A1:A10",
            CompactRefType::Range {
                sheet: None,
                start_row: 1,
                start_col: 1,
                end_row: 10,
                end_col: 1,
                start_row_abs: false,
                start_col_abs: false,
                end_row_abs: false,
                end_col_abs: false,
            },
        );
        let sum = arena.insert_function("SUM", vec![range]);

        let b1 = arena.insert_reference("B1", relative_cell(1, 2));
        let zero = arena.insert_literal(ValueRef::small_int(0).unwrap());
        let condition = arena.insert_binary_op(">", b1, zero);

        let c1 = arena.insert_reference("C1", relative_cell(1, 3));
        let d1 = arena.insert_reference("D1", relative_cell(1, 4));
        let if_expr = arena.insert_function("IF", vec![condition, c1, d1]);

        let final_expr = arena.insert_binary_op("+", sum, if_expr);

        assert!(arena.get(final_expr).is_some());
        // range, SUM, B1, 0, B1>0, C1, D1, IF, final `+`
        assert_eq!(arena.stats().node_count, 9);
    }

    #[test]
    fn test_ast_arena_string_deduplication() {
        let mut arena = AstArena::new();

        let one = arena.insert_literal(ValueRef::small_int(1).unwrap());
        let two = arena.insert_literal(ValueRef::small_int(2).unwrap());
        let three = arena.insert_literal(ValueRef::small_int(3).unwrap());

        let add1 = arena.insert_binary_op("+", one, two);
        let add2 = arena.insert_binary_op("+", two, three);
        let add3 = arena.insert_binary_op("+", one, three);

        // Distinct operands give distinct nodes, but "+" is interned only once.
        assert_ne!(add1, add2);
        assert_ne!(add2, add3);
        assert_ne!(add1, add3);
        assert_eq!(arena.strings().len(), 1);
    }

    #[test]
    fn test_ast_arena_clear() {
        let mut arena = AstArena::new();

        arena.insert_literal(ValueRef::small_int(1).unwrap());
        arena.insert_literal(ValueRef::small_int(2).unwrap());
        let left = arena.insert_literal(ValueRef::small_int(3).unwrap());
        let right = arena.insert_literal(ValueRef::small_int(4).unwrap());
        arena.insert_binary_op("+", left, right);

        assert_eq!(arena.stats().node_count, 5);

        arena.clear();

        assert_eq!(arena.stats().node_count, 0);
        assert_eq!(arena.strings().len(), 0);
    }
}