1use super::string_interner::{StringId, StringInterner};
4use super::value_ref::ValueRef;
5use formualizer_parse::parser::{ExternalRefKind, TableSpecifier};
6use rustc_hash::FxHashMap;
7use std::collections::hash_map::DefaultHasher;
8use std::fmt;
9use std::hash::{Hash, Hasher};
10
/// Identifier of a node stored in an [`AstArena`].
///
/// The wrapped `u32` is the node's position in the arena's node vector.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct AstNodeId(u32);

impl AstNodeId {
    /// Returns the raw `u32` wrapped by this id.
    pub fn as_u32(self) -> u32 {
        let Self(raw) = self;
        raw
    }
}

impl fmt::Display for AstNodeId {
    /// Renders the id as `AstNode(<raw>)`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "AstNode({})", self.as_u32())
    }
}
26
/// Identifier of a table specifier interned in an [`AstArena`].
///
/// The wrapped `u32` is the specifier's position in the arena's
/// table-specifier vector.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct TableSpecId(u32);

impl TableSpecId {
    /// Returns the raw `u32` wrapped by this id.
    pub fn as_u32(self) -> u32 {
        let Self(raw) = self;
        raw
    }
}
35
36#[derive(Debug, Clone, PartialEq, Eq, Hash)]
38pub enum AstNodeData {
39 Literal(ValueRef),
41
42 Reference {
44 original_id: StringId, ref_type: CompactRefType, },
47
48 UnaryOp { op_id: StringId, expr_id: AstNodeId },
50
51 BinaryOp {
53 op_id: StringId,
54 left_id: AstNodeId,
55 right_id: AstNodeId,
56 },
57
58 Function {
60 name_id: StringId,
61 args_offset: u32, args_count: u16, },
64
65 Array {
67 rows: u16,
68 cols: u16,
69 elements_offset: u32, },
71}
72
73#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
75pub enum SheetKey {
76 Id(u16),
77 Name(StringId),
78}
79
80#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
82pub enum CompactRefType {
83 Cell {
84 sheet: Option<SheetKey>,
85 row: u32,
86 col: u32,
87 row_abs: bool,
88 col_abs: bool,
89 },
90 Range {
91 sheet: Option<SheetKey>,
92 start_row: u32,
93 start_col: u32,
94 end_row: u32,
95 end_col: u32,
96 start_row_abs: bool,
97 start_col_abs: bool,
98 end_row_abs: bool,
99 end_col_abs: bool,
100 },
101 External {
102 raw_id: StringId,
103 book_id: StringId,
104 sheet_id: StringId,
105 kind: ExternalRefKind,
106 },
107 NamedRange(StringId),
108 Table {
109 name_id: StringId,
110 specifier_id: Option<TableSpecId>,
111 },
112 Cell3D {
114 sheet_first: StringId,
115 sheet_last: StringId,
116 row: u32,
117 col: u32,
118 row_abs: bool,
119 col_abs: bool,
120 },
121 Range3D {
123 sheet_first: StringId,
124 sheet_last: StringId,
125 start_row: u32,
126 start_col: u32,
127 end_row: u32,
128 end_col: u32,
129 start_row_abs: bool,
130 start_col_abs: bool,
131 end_row_abs: bool,
132 end_col_abs: bool,
133 },
134}
135
136pub struct AstArena {
138 nodes: Vec<AstNodeData>,
140
141 dedup_map: FxHashMap<u64, AstNodeId>,
143
144 function_args: Vec<AstNodeId>,
146
147 array_elements: Vec<AstNodeId>,
149
150 strings: StringInterner,
152
153 table_specs: Vec<TableSpecifier>,
155 table_spec_dedup: FxHashMap<u64, TableSpecId>,
156
157 dedup_hits: usize,
159}
160
161impl AstArena {
162 pub fn new() -> Self {
163 Self {
164 nodes: Vec::new(),
165 dedup_map: FxHashMap::default(),
166 function_args: Vec::new(),
167 array_elements: Vec::new(),
168 strings: StringInterner::new(),
169 table_specs: Vec::new(),
170 table_spec_dedup: FxHashMap::default(),
171 dedup_hits: 0,
172 }
173 }
174
175 pub fn with_capacity(node_cap: usize) -> Self {
176 Self {
177 nodes: Vec::with_capacity(node_cap),
178 dedup_map: FxHashMap::with_capacity_and_hasher(node_cap, Default::default()),
179 function_args: Vec::with_capacity(node_cap * 2), array_elements: Vec::with_capacity(node_cap),
181 strings: StringInterner::with_capacity(node_cap / 10),
182 table_specs: Vec::new(),
183 table_spec_dedup: FxHashMap::default(),
184 dedup_hits: 0,
185 }
186 }
187
188 pub fn insert(&mut self, node: AstNodeData) -> AstNodeId {
190 let hash = self.hash_node(&node);
192
193 if let Some(&id) = self.dedup_map.get(&hash) {
195 if self.nodes[id.0 as usize] == node {
197 self.dedup_hits += 1;
198 return id;
199 }
200 }
201
202 let id = AstNodeId(self.nodes.len() as u32);
204 self.nodes.push(node);
205 self.dedup_map.insert(hash, id);
206 id
207 }
208
209 pub fn insert_literal(&mut self, value: ValueRef) -> AstNodeId {
211 self.insert(AstNodeData::Literal(value))
212 }
213
214 pub fn insert_reference(&mut self, original: &str, ref_type: CompactRefType) -> AstNodeId {
216 let original_id = self.strings.intern(original);
217 self.insert(AstNodeData::Reference {
218 original_id,
219 ref_type,
220 })
221 }
222
223 pub fn insert_unary_op(&mut self, op: &str, expr: AstNodeId) -> AstNodeId {
225 let op_id = self.strings.intern(op);
226 self.insert(AstNodeData::UnaryOp {
227 op_id,
228 expr_id: expr,
229 })
230 }
231
232 pub fn insert_binary_op(&mut self, op: &str, left: AstNodeId, right: AstNodeId) -> AstNodeId {
234 let op_id = self.strings.intern(op);
235 self.insert(AstNodeData::BinaryOp {
236 op_id,
237 left_id: left,
238 right_id: right,
239 })
240 }
241
242 pub fn insert_function(&mut self, name: &str, args: Vec<AstNodeId>) -> AstNodeId {
244 let name_id = self.strings.intern(name);
245 let args_offset = self.function_args.len() as u32;
246 let args_count = args.len() as u16;
247
248 self.function_args.extend(args);
249
250 self.insert(AstNodeData::Function {
251 name_id,
252 args_offset,
253 args_count,
254 })
255 }
256
257 pub fn insert_array(&mut self, rows: u16, cols: u16, elements: Vec<AstNodeId>) -> AstNodeId {
259 assert_eq!(
260 elements.len(),
261 (rows * cols) as usize,
262 "Array dimensions don't match element count"
263 );
264
265 let elements_offset = self.array_elements.len() as u32;
266 self.array_elements.extend(elements);
267
268 self.insert(AstNodeData::Array {
269 rows,
270 cols,
271 elements_offset,
272 })
273 }
274
275 pub fn get(&self, id: AstNodeId) -> Option<&AstNodeData> {
277 self.nodes.get(id.0 as usize)
278 }
279
280 pub fn get_function_args(&self, id: AstNodeId) -> Option<&[AstNodeId]> {
282 match self.get(id)? {
283 AstNodeData::Function {
284 args_offset,
285 args_count,
286 ..
287 } => {
288 let start = *args_offset as usize;
289 let end = start + *args_count as usize;
290 Some(&self.function_args[start..end])
291 }
292 _ => None,
293 }
294 }
295
296 pub fn get_array_elements(&self, id: AstNodeId) -> Option<&[AstNodeId]> {
298 match self.get(id)? {
299 AstNodeData::Array {
300 rows,
301 cols,
302 elements_offset,
303 } => {
304 let start = *elements_offset as usize;
305 let count = (*rows * *cols) as usize;
306 let end = start + count;
307 Some(&self.array_elements[start..end])
308 }
309 _ => None,
310 }
311 }
312
313 pub fn get_array_elements_info(&self, id: AstNodeId) -> Option<(u16, u16, &[AstNodeId])> {
314 match self.get(id)? {
315 AstNodeData::Array { rows, cols, .. } => {
316 let elements = self.get_array_elements(id)?;
317 Some((*rows, *cols, elements))
318 }
319 _ => None,
320 }
321 }
322
323 pub fn resolve_string(&self, id: StringId) -> &str {
325 self.strings.resolve(id)
326 }
327
328 pub fn strings(&self) -> &StringInterner {
330 &self.strings
331 }
332
333 pub fn strings_mut(&mut self) -> &mut StringInterner {
335 &mut self.strings
336 }
337
338 pub fn intern_table_specifier(&mut self, specifier: &TableSpecifier) -> TableSpecId {
339 let hash = {
340 let mut hasher = DefaultHasher::new();
341 specifier.hash(&mut hasher);
342 hasher.finish()
343 };
344
345 if let Some(&id) = self.table_spec_dedup.get(&hash)
346 && self
347 .table_specs
348 .get(id.0 as usize)
349 .is_some_and(|existing| existing == specifier)
350 {
351 return id;
352 }
353
354 let id = TableSpecId(self.table_specs.len() as u32);
355 self.table_specs.push(specifier.clone());
356 self.table_spec_dedup.insert(hash, id);
357 id
358 }
359
360 pub fn resolve_table_specifier(&self, id: TableSpecId) -> Option<&TableSpecifier> {
361 self.table_specs.get(id.0 as usize)
362 }
363
364 fn hash_node(&self, node: &AstNodeData) -> u64 {
366 let mut hasher = DefaultHasher::new();
367 node.hash(&mut hasher);
368 hasher.finish()
369 }
370
371 pub fn stats(&self) -> AstArenaStats {
373 AstArenaStats {
374 node_count: self.nodes.len(),
375 dedup_hits: self.dedup_hits,
376 string_count: self.strings.len(),
377 table_spec_count: self.table_specs.len(),
378 total_args: self.function_args.len(),
379 total_array_elements: self.array_elements.len(),
380 }
381 }
382
383 pub fn memory_usage(&self) -> usize {
385 self.nodes.capacity() * std::mem::size_of::<AstNodeData>()
386 + self.dedup_map.capacity() * (8 + 4) + self.function_args.capacity() * 4
388 + self.array_elements.capacity() * 4
389 + self.strings.memory_usage()
390 + self.table_specs.capacity() * std::mem::size_of::<TableSpecifier>()
391 + self.table_spec_dedup.capacity() * (8 + 4)
392 }
393
394 pub fn clear(&mut self) {
396 self.nodes.clear();
397 self.dedup_map.clear();
398 self.function_args.clear();
399 self.array_elements.clear();
400 self.strings.clear();
401 self.table_specs.clear();
402 self.table_spec_dedup.clear();
403 self.dedup_hits = 0;
404 }
405}
406
407impl Default for AstArena {
408 fn default() -> Self {
409 Self::new()
410 }
411}
412
413impl fmt::Debug for AstArena {
414 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
415 f.debug_struct("AstArena")
416 .field("nodes", &self.nodes.len())
417 .field("dedup_hits", &self.dedup_hits)
418 .field("strings", &self.strings.len())
419 .finish()
420 }
421}
422
/// Counters describing the contents of an [`AstArena`] at one point in time.
#[derive(Clone, Debug)]
pub struct AstArenaStats {
    /// Number of nodes stored.
    pub node_count: usize,
    /// Number of node insertions that reused an existing node.
    pub dedup_hits: usize,
    /// Number of strings reported by the arena's interner.
    pub string_count: usize,
    /// Number of table specifiers stored.
    pub table_spec_count: usize,
    /// Total length of the shared function-argument buffer.
    pub total_args: usize,
    /// Total length of the shared array-element buffer.
    pub total_array_elements: usize,
}
433
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a sheet-less cell reference with both `*_abs` flags cleared.
    fn relative_cell(row: u32, col: u32) -> CompactRefType {
        CompactRefType::Cell {
            sheet: None,
            row,
            col,
            row_abs: false,
            col_abs: false,
        }
    }

    #[test]
    fn test_ast_arena_literal() {
        let mut arena = AstArena::new();

        let lit1 = arena.insert_literal(ValueRef::small_int(42).unwrap());
        let lit2 = arena.insert_literal(ValueRef::boolean(true));

        assert_ne!(lit1, lit2);

        match arena.get(lit1) {
            Some(AstNodeData::Literal(v)) => {
                assert_eq!(v.as_small_int(), Some(42));
            }
            _ => panic!("Expected literal node"),
        }
    }

    #[test]
    fn test_ast_arena_deduplication() {
        let mut arena = AstArena::new();

        // Inserting the same literal twice must hand back the same id.
        let lit1 = arena.insert_literal(ValueRef::small_int(42).unwrap());
        let lit2 = arena.insert_literal(ValueRef::small_int(42).unwrap());

        assert_eq!(lit1, lit2);
        assert_eq!(arena.stats().dedup_hits, 1);
    }

    #[test]
    fn test_ast_arena_binary_op() {
        let mut arena = AstArena::new();

        let left = arena.insert_literal(ValueRef::small_int(1).unwrap());
        let right = arena.insert_literal(ValueRef::small_int(2).unwrap());
        let add = arena.insert_binary_op("+", left, right);

        match arena.get(add) {
            Some(AstNodeData::BinaryOp {
                op_id,
                left_id,
                right_id,
            }) => {
                assert_eq!(arena.resolve_string(*op_id), "+");
                assert_eq!(*left_id, left);
                assert_eq!(*right_id, right);
            }
            _ => panic!("Expected binary op node"),
        }
    }

    #[test]
    fn test_ast_arena_function() {
        let mut arena = AstArena::new();

        let arg1 = arena.insert_literal(ValueRef::small_int(10).unwrap());
        let arg2 = arena.insert_literal(ValueRef::small_int(20).unwrap());
        let arg3 = arena.insert_literal(ValueRef::small_int(30).unwrap());

        let func = arena.insert_function("SUM", vec![arg1, arg2, arg3]);

        match arena.get(func) {
            Some(AstNodeData::Function {
                name_id,
                args_count,
                ..
            }) => {
                assert_eq!(arena.resolve_string(*name_id), "SUM");
                assert_eq!(*args_count, 3);
            }
            _ => panic!("Expected function node"),
        }

        let args = arena.get_function_args(func).unwrap();
        assert_eq!(args, &[arg1, arg2, arg3]);
    }

    #[test]
    fn test_ast_arena_structural_sharing() {
        let mut arena = AstArena::new();

        let a1_ref = arena.insert_reference("A1", relative_cell(1, 1));

        // A1 + 1
        let one = arena.insert_literal(ValueRef::small_int(1).unwrap());
        let expr1 = arena.insert_binary_op("+", a1_ref, one);

        // A1 * 2, sharing the A1 node with the expression above.
        let two = arena.insert_literal(ValueRef::small_int(2).unwrap());
        let expr2 = arena.insert_binary_op("*", a1_ref, two);

        assert_ne!(expr1, expr2);
        // A1, 1, A1 + 1, 2, A1 * 2
        assert_eq!(arena.stats().node_count, 5);

        // Re-inserting A1 must reuse the existing node.
        let a1_ref2 = arena.insert_reference("A1", relative_cell(1, 1));
        assert_eq!(a1_ref, a1_ref2);
    }

    #[test]
    fn test_ast_arena_array() {
        let mut arena = AstArena::new();

        let elements = vec![
            arena.insert_literal(ValueRef::small_int(1).unwrap()),
            arena.insert_literal(ValueRef::small_int(2).unwrap()),
            arena.insert_literal(ValueRef::small_int(3).unwrap()),
            arena.insert_literal(ValueRef::small_int(4).unwrap()),
        ];

        let array = arena.insert_array(2, 2, elements.clone());

        match arena.get(array) {
            Some(AstNodeData::Array { rows, cols, .. }) => {
                assert_eq!(*rows, 2);
                assert_eq!(*cols, 2);
            }
            _ => panic!("Expected array node"),
        }

        let stored_elements = arena.get_array_elements(array).unwrap();
        assert_eq!(stored_elements, &elements[..]);
    }

    #[test]
    fn test_ast_arena_complex_expression() {
        let mut arena = AstArena::new();

        // Builds SUM(A1:A10) + IF(B1 > 0, C1, D1).
        let range = arena.insert_reference(
            "A1:A10",
            CompactRefType::Range {
                sheet: None,
                start_row: 1,
                start_col: 1,
                end_row: 10,
                end_col: 1,
                start_row_abs: false,
                start_col_abs: false,
                end_row_abs: false,
                end_col_abs: false,
            },
        );
        let sum = arena.insert_function("SUM", vec![range]);

        let b1 = arena.insert_reference("B1", relative_cell(1, 2));
        let zero = arena.insert_literal(ValueRef::small_int(0).unwrap());
        let condition = arena.insert_binary_op(">", b1, zero);

        let c1 = arena.insert_reference("C1", relative_cell(1, 3));
        let d1 = arena.insert_reference("D1", relative_cell(1, 4));
        let if_expr = arena.insert_function("IF", vec![condition, c1, d1]);

        let final_expr = arena.insert_binary_op("+", sum, if_expr);

        assert!(arena.get(final_expr).is_some());
        // range, SUM, B1, 0, B1 > 0, C1, D1, IF, +
        assert_eq!(arena.stats().node_count, 9);
    }

    #[test]
    fn test_ast_arena_string_deduplication() {
        let mut arena = AstArena::new();

        let one = arena.insert_literal(ValueRef::small_int(1).unwrap());
        let two = arena.insert_literal(ValueRef::small_int(2).unwrap());
        let three = arena.insert_literal(ValueRef::small_int(3).unwrap());

        let add1 = arena.insert_binary_op("+", one, two);
        let add2 = arena.insert_binary_op("+", two, three);
        let add3 = arena.insert_binary_op("+", one, three);

        // Three distinct nodes...
        assert_ne!(add1, add2);
        assert_ne!(add2, add3);
        assert_ne!(add1, add3);
        // ...that all share a single interned "+".
        assert_eq!(arena.strings().len(), 1);
    }

    #[test]
    fn test_ast_arena_clear() {
        let mut arena = AstArena::new();

        arena.insert_literal(ValueRef::small_int(1).unwrap());
        arena.insert_literal(ValueRef::small_int(2).unwrap());
        let left = arena.insert_literal(ValueRef::small_int(3).unwrap());
        let right = arena.insert_literal(ValueRef::small_int(4).unwrap());
        arena.insert_binary_op("+", left, right);

        assert_eq!(arena.stats().node_count, 5);

        arena.clear();

        assert_eq!(arena.stats().node_count, 0);
        assert_eq!(arena.strings().len(), 0);
    }
}