1use super::string_interner::{StringId, StringInterner};
4use super::value_ref::ValueRef;
5use formualizer_parse::parser::{ExternalRefKind, TableSpecifier};
6use rustc_hash::FxHashMap;
7use std::collections::hash_map::DefaultHasher;
8use std::fmt;
9use std::hash::{Hash, Hasher};
10
/// Index of a node stored in an [`AstArena`].
///
/// The arena hands these out in insertion order; the wrapped value is the
/// node's position in the arena's node storage.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
pub struct AstNodeId(u32);

impl AstNodeId {
    /// Returns the raw index wrapped by this id.
    pub fn as_u32(self) -> u32 {
        let Self(raw) = self;
        raw
    }
}

impl fmt::Display for AstNodeId {
    /// Renders the id as `AstNode(<index>)`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "AstNode({})", self.as_u32())
    }
}
26
/// Handle to a table specifier interned by `AstArena::intern_table_specifier`.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
pub struct TableSpecId(u32);

impl TableSpecId {
    /// Returns the raw index into the arena's table-specifier storage.
    pub fn as_u32(self) -> u32 {
        let Self(index) = self;
        index
    }
}
35
36#[derive(Debug, Clone, PartialEq, Eq, Hash)]
38pub enum AstNodeData {
39 Literal(ValueRef),
41
42 Reference {
44 original_id: StringId, ref_type: CompactRefType, },
47
48 UnaryOp { op_id: StringId, expr_id: AstNodeId },
50
51 BinaryOp {
53 op_id: StringId,
54 left_id: AstNodeId,
55 right_id: AstNodeId,
56 },
57
58 Function {
60 name_id: StringId,
61 args_offset: u32, args_count: u16, },
64
65 Array {
67 rows: u16,
68 cols: u16,
69 elements_offset: u32, },
71}
72
73#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
75pub enum SheetKey {
76 Id(u16),
77 Name(StringId),
78}
79
80#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
82pub enum CompactRefType {
83 Cell {
84 sheet: Option<SheetKey>,
85 row: u32,
86 col: u32,
87 row_abs: bool,
88 col_abs: bool,
89 },
90 Range {
91 sheet: Option<SheetKey>,
92 start_row: u32,
93 start_col: u32,
94 end_row: u32,
95 end_col: u32,
96 start_row_abs: bool,
97 start_col_abs: bool,
98 end_row_abs: bool,
99 end_col_abs: bool,
100 },
101 External {
102 raw_id: StringId,
103 book_id: StringId,
104 sheet_id: StringId,
105 kind: ExternalRefKind,
106 },
107 NamedRange(StringId),
108 Table {
109 name_id: StringId,
110 specifier_id: Option<TableSpecId>,
111 },
112}
113
114pub struct AstArena {
116 nodes: Vec<AstNodeData>,
118
119 dedup_map: FxHashMap<u64, AstNodeId>,
121
122 function_args: Vec<AstNodeId>,
124
125 array_elements: Vec<AstNodeId>,
127
128 strings: StringInterner,
130
131 table_specs: Vec<TableSpecifier>,
133 table_spec_dedup: FxHashMap<u64, TableSpecId>,
134
135 dedup_hits: usize,
137}
138
139impl AstArena {
140 pub fn new() -> Self {
141 Self {
142 nodes: Vec::new(),
143 dedup_map: FxHashMap::default(),
144 function_args: Vec::new(),
145 array_elements: Vec::new(),
146 strings: StringInterner::new(),
147 table_specs: Vec::new(),
148 table_spec_dedup: FxHashMap::default(),
149 dedup_hits: 0,
150 }
151 }
152
153 pub fn with_capacity(node_cap: usize) -> Self {
154 Self {
155 nodes: Vec::with_capacity(node_cap),
156 dedup_map: FxHashMap::with_capacity_and_hasher(node_cap, Default::default()),
157 function_args: Vec::with_capacity(node_cap * 2), array_elements: Vec::with_capacity(node_cap),
159 strings: StringInterner::with_capacity(node_cap / 10),
160 table_specs: Vec::new(),
161 table_spec_dedup: FxHashMap::default(),
162 dedup_hits: 0,
163 }
164 }
165
166 pub fn insert(&mut self, node: AstNodeData) -> AstNodeId {
168 let hash = self.hash_node(&node);
170
171 if let Some(&id) = self.dedup_map.get(&hash) {
173 if self.nodes[id.0 as usize] == node {
175 self.dedup_hits += 1;
176 return id;
177 }
178 }
179
180 let id = AstNodeId(self.nodes.len() as u32);
182 self.nodes.push(node);
183 self.dedup_map.insert(hash, id);
184 id
185 }
186
187 pub fn insert_literal(&mut self, value: ValueRef) -> AstNodeId {
189 self.insert(AstNodeData::Literal(value))
190 }
191
192 pub fn insert_reference(&mut self, original: &str, ref_type: CompactRefType) -> AstNodeId {
194 let original_id = self.strings.intern(original);
195 self.insert(AstNodeData::Reference {
196 original_id,
197 ref_type,
198 })
199 }
200
201 pub fn insert_unary_op(&mut self, op: &str, expr: AstNodeId) -> AstNodeId {
203 let op_id = self.strings.intern(op);
204 self.insert(AstNodeData::UnaryOp {
205 op_id,
206 expr_id: expr,
207 })
208 }
209
210 pub fn insert_binary_op(&mut self, op: &str, left: AstNodeId, right: AstNodeId) -> AstNodeId {
212 let op_id = self.strings.intern(op);
213 self.insert(AstNodeData::BinaryOp {
214 op_id,
215 left_id: left,
216 right_id: right,
217 })
218 }
219
220 pub fn insert_function(&mut self, name: &str, args: Vec<AstNodeId>) -> AstNodeId {
222 let name_id = self.strings.intern(name);
223 let args_offset = self.function_args.len() as u32;
224 let args_count = args.len() as u16;
225
226 self.function_args.extend(args);
227
228 self.insert(AstNodeData::Function {
229 name_id,
230 args_offset,
231 args_count,
232 })
233 }
234
235 pub fn insert_array(&mut self, rows: u16, cols: u16, elements: Vec<AstNodeId>) -> AstNodeId {
237 assert_eq!(
238 elements.len(),
239 (rows * cols) as usize,
240 "Array dimensions don't match element count"
241 );
242
243 let elements_offset = self.array_elements.len() as u32;
244 self.array_elements.extend(elements);
245
246 self.insert(AstNodeData::Array {
247 rows,
248 cols,
249 elements_offset,
250 })
251 }
252
253 pub fn get(&self, id: AstNodeId) -> Option<&AstNodeData> {
255 self.nodes.get(id.0 as usize)
256 }
257
258 pub fn get_function_args(&self, id: AstNodeId) -> Option<&[AstNodeId]> {
260 match self.get(id)? {
261 AstNodeData::Function {
262 args_offset,
263 args_count,
264 ..
265 } => {
266 let start = *args_offset as usize;
267 let end = start + *args_count as usize;
268 Some(&self.function_args[start..end])
269 }
270 _ => None,
271 }
272 }
273
274 pub fn get_array_elements(&self, id: AstNodeId) -> Option<&[AstNodeId]> {
276 match self.get(id)? {
277 AstNodeData::Array {
278 rows,
279 cols,
280 elements_offset,
281 } => {
282 let start = *elements_offset as usize;
283 let count = (*rows * *cols) as usize;
284 let end = start + count;
285 Some(&self.array_elements[start..end])
286 }
287 _ => None,
288 }
289 }
290
291 pub fn get_array_elements_info(&self, id: AstNodeId) -> Option<(u16, u16, &[AstNodeId])> {
292 match self.get(id)? {
293 AstNodeData::Array { rows, cols, .. } => {
294 let elements = self.get_array_elements(id)?;
295 Some((*rows, *cols, elements))
296 }
297 _ => None,
298 }
299 }
300
301 pub fn resolve_string(&self, id: StringId) -> &str {
303 self.strings.resolve(id)
304 }
305
306 pub fn strings(&self) -> &StringInterner {
308 &self.strings
309 }
310
311 pub fn strings_mut(&mut self) -> &mut StringInterner {
313 &mut self.strings
314 }
315
316 pub fn intern_table_specifier(&mut self, specifier: &TableSpecifier) -> TableSpecId {
317 let hash = {
318 let mut hasher = DefaultHasher::new();
319 specifier.hash(&mut hasher);
320 hasher.finish()
321 };
322
323 if let Some(&id) = self.table_spec_dedup.get(&hash)
324 && self
325 .table_specs
326 .get(id.0 as usize)
327 .is_some_and(|existing| existing == specifier)
328 {
329 return id;
330 }
331
332 let id = TableSpecId(self.table_specs.len() as u32);
333 self.table_specs.push(specifier.clone());
334 self.table_spec_dedup.insert(hash, id);
335 id
336 }
337
338 pub fn resolve_table_specifier(&self, id: TableSpecId) -> Option<&TableSpecifier> {
339 self.table_specs.get(id.0 as usize)
340 }
341
342 fn hash_node(&self, node: &AstNodeData) -> u64 {
344 let mut hasher = DefaultHasher::new();
345 node.hash(&mut hasher);
346 hasher.finish()
347 }
348
349 pub fn stats(&self) -> AstArenaStats {
351 AstArenaStats {
352 node_count: self.nodes.len(),
353 dedup_hits: self.dedup_hits,
354 string_count: self.strings.len(),
355 table_spec_count: self.table_specs.len(),
356 total_args: self.function_args.len(),
357 total_array_elements: self.array_elements.len(),
358 }
359 }
360
361 pub fn memory_usage(&self) -> usize {
363 self.nodes.capacity() * std::mem::size_of::<AstNodeData>()
364 + self.dedup_map.capacity() * (8 + 4) + self.function_args.capacity() * 4
366 + self.array_elements.capacity() * 4
367 + self.strings.memory_usage()
368 + self.table_specs.capacity() * std::mem::size_of::<TableSpecifier>()
369 + self.table_spec_dedup.capacity() * (8 + 4)
370 }
371
372 pub fn clear(&mut self) {
374 self.nodes.clear();
375 self.dedup_map.clear();
376 self.function_args.clear();
377 self.array_elements.clear();
378 self.strings.clear();
379 self.table_specs.clear();
380 self.table_spec_dedup.clear();
381 self.dedup_hits = 0;
382 }
383}
384
385impl Default for AstArena {
386 fn default() -> Self {
387 Self::new()
388 }
389}
390
391impl fmt::Debug for AstArena {
392 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
393 f.debug_struct("AstArena")
394 .field("nodes", &self.nodes.len())
395 .field("dedup_hits", &self.dedup_hits)
396 .field("strings", &self.strings.len())
397 .finish()
398 }
399}
400
/// Snapshot of an `AstArena`'s size counters, as returned by `AstArena::stats`.
///
/// Derives `PartialEq`/`Eq`/`Default` so snapshots can be compared directly.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct AstArenaStats {
    /// Number of distinct nodes stored.
    pub node_count: usize,
    /// Number of insertions answered with an already-stored node.
    pub dedup_hits: usize,
    /// Number of interned strings.
    pub string_count: usize,
    /// Number of stored table specifiers.
    pub table_spec_count: usize,
    /// Length of the function-argument side table.
    pub total_args: usize,
    /// Length of the array-element side table.
    pub total_array_elements: usize,
}
411
#[cfg(test)]
mod tests {
    use super::*;

    /// Relative, sheet-less cell reference with the given row and column.
    fn cell(row: u32, col: u32) -> CompactRefType {
        CompactRefType::Cell {
            sheet: None,
            row,
            col,
            row_abs: false,
            col_abs: false,
        }
    }

    #[test]
    fn test_ast_arena_literal() {
        let mut arena = AstArena::new();

        let lit1 = arena.insert_literal(ValueRef::small_int(42).unwrap());
        let lit2 = arena.insert_literal(ValueRef::boolean(true));

        assert_ne!(lit1, lit2);

        match arena.get(lit1) {
            Some(AstNodeData::Literal(v)) => assert_eq!(v.as_small_int(), Some(42)),
            _ => panic!("Expected literal node"),
        }
    }

    #[test]
    fn test_ast_arena_deduplication() {
        let mut arena = AstArena::new();

        let lit1 = arena.insert_literal(ValueRef::small_int(42).unwrap());
        // The same value again must resolve to the already-stored node.
        let lit2 = arena.insert_literal(ValueRef::small_int(42).unwrap());

        assert_eq!(lit1, lit2);
        assert_eq!(arena.stats().dedup_hits, 1);
        assert_eq!(arena.stats().node_count, 1);
    }

    #[test]
    fn test_ast_arena_binary_op() {
        let mut arena = AstArena::new();

        let left = arena.insert_literal(ValueRef::small_int(1).unwrap());
        let right = arena.insert_literal(ValueRef::small_int(2).unwrap());
        let add = arena.insert_binary_op("+", left, right);

        match arena.get(add) {
            Some(AstNodeData::BinaryOp {
                op_id,
                left_id,
                right_id,
            }) => {
                assert_eq!(arena.resolve_string(*op_id), "+");
                assert_eq!(*left_id, left);
                assert_eq!(*right_id, right);
            }
            _ => panic!("Expected binary op node"),
        }
    }

    #[test]
    fn test_ast_arena_function() {
        let mut arena = AstArena::new();

        let arg1 = arena.insert_literal(ValueRef::small_int(10).unwrap());
        let arg2 = arena.insert_literal(ValueRef::small_int(20).unwrap());
        let arg3 = arena.insert_literal(ValueRef::small_int(30).unwrap());

        let func = arena.insert_function("SUM", vec![arg1, arg2, arg3]);

        match arena.get(func) {
            Some(AstNodeData::Function {
                name_id,
                args_count,
                ..
            }) => {
                assert_eq!(arena.resolve_string(*name_id), "SUM");
                assert_eq!(*args_count, 3);
            }
            _ => panic!("Expected function node"),
        }

        let args = arena.get_function_args(func).unwrap();
        assert_eq!(args, &[arg1, arg2, arg3]);
    }

    #[test]
    fn test_ast_arena_zero_arg_function() {
        let mut arena = AstArena::new();

        let now = arena.insert_function("NOW", Vec::new());

        assert!(arena.get_function_args(now).unwrap().is_empty());
    }

    #[test]
    fn test_ast_arena_structural_sharing() {
        let mut arena = AstArena::new();

        // `A1 + 1` and `A1 * 2` share the single A1 reference node.
        let a1_ref = arena.insert_reference("A1", cell(1, 1));
        let one = arena.insert_literal(ValueRef::small_int(1).unwrap());
        let expr1 = arena.insert_binary_op("+", a1_ref, one);
        let two = arena.insert_literal(ValueRef::small_int(2).unwrap());
        let expr2 = arena.insert_binary_op("*", a1_ref, two);

        assert_ne!(expr1, expr2);
        // a1_ref, one, expr1, two, expr2
        assert_eq!(arena.stats().node_count, 5);

        for expr in [expr1, expr2] {
            match arena.get(expr) {
                Some(AstNodeData::BinaryOp { left_id, .. }) => assert_eq!(*left_id, a1_ref),
                _ => panic!("Expected binary op node"),
            }
        }

        let a1_ref2 = arena.insert_reference("A1", cell(1, 1));
        assert_eq!(a1_ref, a1_ref2);
    }

    #[test]
    fn test_ast_arena_array() {
        let mut arena = AstArena::new();

        let elements = vec![
            arena.insert_literal(ValueRef::small_int(1).unwrap()),
            arena.insert_literal(ValueRef::small_int(2).unwrap()),
            arena.insert_literal(ValueRef::small_int(3).unwrap()),
            arena.insert_literal(ValueRef::small_int(4).unwrap()),
        ];

        let array = arena.insert_array(2, 2, elements.clone());

        match arena.get(array) {
            Some(AstNodeData::Array { rows, cols, .. }) => {
                assert_eq!(*rows, 2);
                assert_eq!(*cols, 2);
            }
            _ => panic!("Expected array node"),
        }

        let stored_elements = arena.get_array_elements(array).unwrap();
        assert_eq!(stored_elements, &elements[..]);

        let (rows, cols, info_elements) = arena.get_array_elements_info(array).unwrap();
        assert_eq!((rows, cols), (2, 2));
        assert_eq!(info_elements, &elements[..]);
    }

    #[test]
    fn test_ast_arena_accessors_reject_wrong_kind() {
        let mut arena = AstArena::new();

        let lit = arena.insert_literal(ValueRef::small_int(1).unwrap());

        assert!(arena.get_function_args(lit).is_none());
        assert!(arena.get_array_elements(lit).is_none());
        assert!(arena.get_array_elements_info(lit).is_none());
        assert!(arena.get(AstNodeId(lit.as_u32() + 1)).is_none());
    }

    #[test]
    fn test_ast_arena_complex_expression() {
        let mut arena = AstArena::new();

        // SUM(A1:A10) + IF(B1 > 0, C1, D1)
        let range = arena.insert_reference(
            "A1:A10",
            CompactRefType::Range {
                sheet: None,
                start_row: 1,
                start_col: 1,
                end_row: 10,
                end_col: 1,
                start_row_abs: false,
                start_col_abs: false,
                end_row_abs: false,
                end_col_abs: false,
            },
        );
        let sum = arena.insert_function("SUM", vec![range]);

        let b1 = arena.insert_reference("B1", cell(1, 2));
        let zero = arena.insert_literal(ValueRef::small_int(0).unwrap());
        let condition = arena.insert_binary_op(">", b1, zero);

        let c1 = arena.insert_reference("C1", cell(1, 3));
        let d1 = arena.insert_reference("D1", cell(1, 4));
        let if_expr = arena.insert_function("IF", vec![condition, c1, d1]);

        let final_expr = arena.insert_binary_op("+", sum, if_expr);

        assert!(arena.get(final_expr).is_some());
        // range, sum, b1, zero, condition, c1, d1, if_expr, final_expr
        assert_eq!(arena.stats().node_count, 9);
    }

    #[test]
    fn test_ast_arena_string_deduplication() {
        let mut arena = AstArena::new();

        let one = arena.insert_literal(ValueRef::small_int(1).unwrap());
        let two = arena.insert_literal(ValueRef::small_int(2).unwrap());
        let three = arena.insert_literal(ValueRef::small_int(3).unwrap());

        let add1 = arena.insert_binary_op("+", one, two);
        let add2 = arena.insert_binary_op("+", two, three);
        let add3 = arena.insert_binary_op("+", one, three);

        // Different operands give three distinct nodes...
        assert_ne!(add1, add2);
        assert_ne!(add2, add3);
        assert_ne!(add1, add3);
        // ...but only one interned "+" string.
        assert_eq!(arena.strings().len(), 1);
    }

    #[test]
    fn test_ast_arena_clear() {
        let mut arena = AstArena::new();

        arena.insert_literal(ValueRef::small_int(1).unwrap());
        arena.insert_literal(ValueRef::small_int(2).unwrap());
        let left = arena.insert_literal(ValueRef::small_int(3).unwrap());
        let right = arena.insert_literal(ValueRef::small_int(4).unwrap());
        arena.insert_binary_op("+", left, right);

        assert_eq!(arena.stats().node_count, 5);

        arena.clear();

        let stats = arena.stats();
        assert_eq!(stats.node_count, 0);
        assert_eq!(stats.dedup_hits, 0);
        assert_eq!(stats.total_args, 0);
        assert_eq!(arena.strings().len(), 0);

        // Ids restart from zero after a clear.
        let fresh = arena.insert_literal(ValueRef::small_int(1).unwrap());
        assert_eq!(fresh.as_u32(), 0);
    }
}