1use super::string_interner::{StringId, StringInterner};
4use super::value_ref::ValueRef;
5use formualizer_parse::parser::TableSpecifier;
6use rustc_hash::FxHashMap;
7use std::collections::hash_map::DefaultHasher;
8use std::fmt;
9use std::hash::{Hash, Hasher};
10
/// Handle to a node stored in an [`AstArena`].
///
/// The wrapped value is the node's index in the arena's node list, so handles are
/// only meaningful for the arena that produced them.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub struct AstNodeId(u32);

impl AstNodeId {
    /// Returns the raw index behind this handle.
    pub fn as_u32(self) -> u32 {
        let AstNodeId(raw) = self;
        raw
    }
}

impl fmt::Display for AstNodeId {
    /// Renders the handle as `AstNode(<index>)`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_fmt(format_args!("AstNode({})", self.as_u32()))
    }
}
26
/// Handle to a table specifier interned by an [`AstArena`].
///
/// The wrapped value is the specifier's index in the arena's table-specifier list.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub struct TableSpecId(u32);

impl TableSpecId {
    /// Returns the raw index behind this handle.
    pub fn as_u32(self) -> u32 {
        let TableSpecId(index) = self;
        index
    }
}
35
/// Payload of a single node stored in an [`AstArena`].
///
/// Children are referenced by [`AstNodeId`] and text by interned [`StringId`]
/// rather than owned directly. Variable-length children (function arguments and
/// array elements) live in side buffers owned by the arena and are addressed by
/// offset.
///
/// NOTE: variant and field order feed the derived `Hash`, `PartialEq` and `Debug`
/// impls, which the arena uses for deduplication.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum AstNodeData {
    /// A literal constant value.
    Literal(ValueRef),

    /// A cell, range, named-range or table reference.
    Reference {
        /// Interned source text of the reference as written (e.g. `"A1"`).
        original_id: StringId,
        /// Compact parsed form of the reference.
        ref_type: CompactRefType,
    },

    /// A unary operator (`op_id` is the interned operator text) applied to `expr_id`.
    UnaryOp { op_id: StringId, expr_id: AstNodeId },

    /// A binary operator applied to two operands.
    BinaryOp {
        /// Interned operator text (e.g. `"+"`).
        op_id: StringId,
        /// Left operand.
        left_id: AstNodeId,
        /// Right operand.
        right_id: AstNodeId,
    },

    /// A function call.
    Function {
        /// Interned function name (e.g. `"SUM"`).
        name_id: StringId,
        /// Start index of this call's arguments in the arena's function-argument buffer.
        args_offset: u32,
        /// Number of consecutive argument ids starting at `args_offset`.
        args_count: u16,
    },

    /// An array literal with `rows * cols` elements.
    Array {
        /// Number of rows.
        rows: u16,
        /// Number of columns.
        cols: u16,
        /// Start index of the `rows * cols` element ids in the arena's array-element
        /// buffer. Element order (row- vs column-major) is whatever the caller passed
        /// to `insert_array`; it is not fixed here.
        elements_offset: u32,
    },
}
72
/// Identifies the sheet a reference points to.
///
/// NOTE(review): presumably `Name` is used when the sheet has not (yet) been
/// resolved to a numeric id — confirm against the parser/resolver.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SheetKey {
    /// Numeric sheet id.
    Id(u16),
    /// Interned sheet name.
    Name(StringId),
}
79
/// Compact, `Copy` representation of a parsed reference.
///
/// Coordinates are 1-based by convention (`A1` is row 1, col 1). The arena
/// itself does not validate them. A `sheet` of `None` means the reference text
/// carried no sheet qualifier.
/// NOTE(review): presumably such references resolve against the formula's own
/// sheet — confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CompactRefType {
    /// A single cell.
    Cell {
        sheet: Option<SheetKey>,
        row: u32,
        col: u32,
    },
    /// A rectangular range from (`start_row`, `start_col`) to (`end_row`, `end_col`).
    /// NOTE(review): bounds appear to be inclusive (`A1:A10` → rows 1..=10) — confirm.
    Range {
        sheet: Option<SheetKey>,
        start_row: u32,
        start_col: u32,
        end_row: u32,
        end_col: u32,
    },
    /// A defined name, stored as an interned string.
    NamedRange(StringId),
    /// A structured table reference.
    Table {
        /// Interned table name.
        name_id: StringId,
        /// Optional handle to a specifier interned via `AstArena::intern_table_specifier`.
        specifier_id: Option<TableSpecId>,
    },
}
101
/// Arena that stores formula AST nodes and deduplicates them on insertion.
///
/// Nodes are addressed by [`AstNodeId`] (an index into `nodes`). Operator names,
/// function names and reference text are interned in a shared [`StringInterner`].
/// Table specifiers are interned separately and addressed by [`TableSpecId`].
pub struct AstArena {
    // Every stored node; an `AstNodeId` is an index into this vector.
    nodes: Vec<AstNodeData>,

    // Node hash -> id of the most recently stored node with that hash. `insert`
    // overwrites the entry when a hash collides with a structurally different node.
    dedup_map: FxHashMap<u64, AstNodeId>,

    // Flat buffer of function-call argument ids, sliced by `args_offset`/`args_count`.
    function_args: Vec<AstNodeId>,

    // Flat buffer of array-literal element ids, sliced by `elements_offset` and `rows * cols`.
    array_elements: Vec<AstNodeId>,

    // Interner for operator text, function names and reference source text.
    strings: StringInterner,

    // Interned table specifiers; a `TableSpecId` is an index into this vector.
    table_specs: Vec<TableSpecifier>,
    // Specifier hash -> id, used by `intern_table_specifier` to reuse equal specifiers.
    table_spec_dedup: FxHashMap<u64, TableSpecId>,

    // Number of `insert` calls that returned an already-stored node.
    dedup_hits: usize,
}
126
127impl AstArena {
128 pub fn new() -> Self {
129 Self {
130 nodes: Vec::new(),
131 dedup_map: FxHashMap::default(),
132 function_args: Vec::new(),
133 array_elements: Vec::new(),
134 strings: StringInterner::new(),
135 table_specs: Vec::new(),
136 table_spec_dedup: FxHashMap::default(),
137 dedup_hits: 0,
138 }
139 }
140
141 pub fn with_capacity(node_cap: usize) -> Self {
142 Self {
143 nodes: Vec::with_capacity(node_cap),
144 dedup_map: FxHashMap::with_capacity_and_hasher(node_cap, Default::default()),
145 function_args: Vec::with_capacity(node_cap * 2), array_elements: Vec::with_capacity(node_cap),
147 strings: StringInterner::with_capacity(node_cap / 10),
148 table_specs: Vec::new(),
149 table_spec_dedup: FxHashMap::default(),
150 dedup_hits: 0,
151 }
152 }
153
154 pub fn insert(&mut self, node: AstNodeData) -> AstNodeId {
156 let hash = self.hash_node(&node);
158
159 if let Some(&id) = self.dedup_map.get(&hash) {
161 if self.nodes[id.0 as usize] == node {
163 self.dedup_hits += 1;
164 return id;
165 }
166 }
167
168 let id = AstNodeId(self.nodes.len() as u32);
170 self.nodes.push(node);
171 self.dedup_map.insert(hash, id);
172 id
173 }
174
175 pub fn insert_literal(&mut self, value: ValueRef) -> AstNodeId {
177 self.insert(AstNodeData::Literal(value))
178 }
179
180 pub fn insert_reference(&mut self, original: &str, ref_type: CompactRefType) -> AstNodeId {
182 let original_id = self.strings.intern(original);
183 self.insert(AstNodeData::Reference {
184 original_id,
185 ref_type,
186 })
187 }
188
189 pub fn insert_unary_op(&mut self, op: &str, expr: AstNodeId) -> AstNodeId {
191 let op_id = self.strings.intern(op);
192 self.insert(AstNodeData::UnaryOp {
193 op_id,
194 expr_id: expr,
195 })
196 }
197
198 pub fn insert_binary_op(&mut self, op: &str, left: AstNodeId, right: AstNodeId) -> AstNodeId {
200 let op_id = self.strings.intern(op);
201 self.insert(AstNodeData::BinaryOp {
202 op_id,
203 left_id: left,
204 right_id: right,
205 })
206 }
207
208 pub fn insert_function(&mut self, name: &str, args: Vec<AstNodeId>) -> AstNodeId {
210 let name_id = self.strings.intern(name);
211 let args_offset = self.function_args.len() as u32;
212 let args_count = args.len() as u16;
213
214 self.function_args.extend(args);
215
216 self.insert(AstNodeData::Function {
217 name_id,
218 args_offset,
219 args_count,
220 })
221 }
222
223 pub fn insert_array(&mut self, rows: u16, cols: u16, elements: Vec<AstNodeId>) -> AstNodeId {
225 assert_eq!(
226 elements.len(),
227 (rows * cols) as usize,
228 "Array dimensions don't match element count"
229 );
230
231 let elements_offset = self.array_elements.len() as u32;
232 self.array_elements.extend(elements);
233
234 self.insert(AstNodeData::Array {
235 rows,
236 cols,
237 elements_offset,
238 })
239 }
240
241 pub fn get(&self, id: AstNodeId) -> Option<&AstNodeData> {
243 self.nodes.get(id.0 as usize)
244 }
245
246 pub fn get_function_args(&self, id: AstNodeId) -> Option<&[AstNodeId]> {
248 match self.get(id)? {
249 AstNodeData::Function {
250 args_offset,
251 args_count,
252 ..
253 } => {
254 let start = *args_offset as usize;
255 let end = start + *args_count as usize;
256 Some(&self.function_args[start..end])
257 }
258 _ => None,
259 }
260 }
261
262 pub fn get_array_elements(&self, id: AstNodeId) -> Option<&[AstNodeId]> {
264 match self.get(id)? {
265 AstNodeData::Array {
266 rows,
267 cols,
268 elements_offset,
269 } => {
270 let start = *elements_offset as usize;
271 let count = (*rows * *cols) as usize;
272 let end = start + count;
273 Some(&self.array_elements[start..end])
274 }
275 _ => None,
276 }
277 }
278
279 pub fn get_array_elements_info(&self, id: AstNodeId) -> Option<(u16, u16, &[AstNodeId])> {
280 match self.get(id)? {
281 AstNodeData::Array { rows, cols, .. } => {
282 let elements = self.get_array_elements(id)?;
283 Some((*rows, *cols, elements))
284 }
285 _ => None,
286 }
287 }
288
289 pub fn resolve_string(&self, id: StringId) -> &str {
291 self.strings.resolve(id)
292 }
293
294 pub fn strings(&self) -> &StringInterner {
296 &self.strings
297 }
298
299 pub fn strings_mut(&mut self) -> &mut StringInterner {
301 &mut self.strings
302 }
303
304 pub fn intern_table_specifier(&mut self, specifier: &TableSpecifier) -> TableSpecId {
305 let hash = {
306 let mut hasher = DefaultHasher::new();
307 specifier.hash(&mut hasher);
308 hasher.finish()
309 };
310
311 if let Some(&id) = self.table_spec_dedup.get(&hash)
312 && self
313 .table_specs
314 .get(id.0 as usize)
315 .is_some_and(|existing| existing == specifier)
316 {
317 return id;
318 }
319
320 let id = TableSpecId(self.table_specs.len() as u32);
321 self.table_specs.push(specifier.clone());
322 self.table_spec_dedup.insert(hash, id);
323 id
324 }
325
326 pub fn resolve_table_specifier(&self, id: TableSpecId) -> Option<&TableSpecifier> {
327 self.table_specs.get(id.0 as usize)
328 }
329
330 fn hash_node(&self, node: &AstNodeData) -> u64 {
332 let mut hasher = DefaultHasher::new();
333 node.hash(&mut hasher);
334 hasher.finish()
335 }
336
337 pub fn stats(&self) -> AstArenaStats {
339 AstArenaStats {
340 node_count: self.nodes.len(),
341 dedup_hits: self.dedup_hits,
342 string_count: self.strings.len(),
343 table_spec_count: self.table_specs.len(),
344 total_args: self.function_args.len(),
345 total_array_elements: self.array_elements.len(),
346 }
347 }
348
349 pub fn memory_usage(&self) -> usize {
351 self.nodes.capacity() * std::mem::size_of::<AstNodeData>()
352 + self.dedup_map.capacity() * (8 + 4) + self.function_args.capacity() * 4
354 + self.array_elements.capacity() * 4
355 + self.strings.memory_usage()
356 + self.table_specs.capacity() * std::mem::size_of::<TableSpecifier>()
357 + self.table_spec_dedup.capacity() * (8 + 4)
358 }
359
360 pub fn clear(&mut self) {
362 self.nodes.clear();
363 self.dedup_map.clear();
364 self.function_args.clear();
365 self.array_elements.clear();
366 self.strings.clear();
367 self.table_specs.clear();
368 self.table_spec_dedup.clear();
369 self.dedup_hits = 0;
370 }
371}
372
373impl Default for AstArena {
374 fn default() -> Self {
375 Self::new()
376 }
377}
378
379impl fmt::Debug for AstArena {
380 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
381 f.debug_struct("AstArena")
382 .field("nodes", &self.nodes.len())
383 .field("dedup_hits", &self.dedup_hits)
384 .field("strings", &self.strings.len())
385 .finish()
386 }
387}
388
/// Snapshot of an [`AstArena`]'s counters, as returned by `AstArena::stats`.
#[derive(Debug, Clone)]
pub struct AstArenaStats {
    /// Number of stored nodes.
    pub node_count: usize,
    /// Number of insertions that returned an already-stored node.
    pub dedup_hits: usize,
    /// Number of interned strings.
    pub string_count: usize,
    /// Number of interned table specifiers.
    pub table_spec_count: usize,
    /// Length of the function-argument buffer.
    pub total_args: usize,
    /// Length of the array-element buffer.
    pub total_array_elements: usize,
}
399
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ast_arena_literal() {
        let mut arena = AstArena::new();

        let lit1 = arena.insert_literal(ValueRef::small_int(42).unwrap());
        let lit2 = arena.insert_literal(ValueRef::boolean(true));

        assert_ne!(lit1, lit2);

        match arena.get(lit1) {
            Some(AstNodeData::Literal(v)) => {
                assert_eq!(v.as_small_int(), Some(42));
            }
            _ => panic!("Expected literal node"),
        }
    }

    #[test]
    fn test_ast_arena_deduplication() {
        let mut arena = AstArena::new();

        let lit1 = arena.insert_literal(ValueRef::small_int(42).unwrap());
        // Inserting an equal literal must return the existing node.
        let lit2 = arena.insert_literal(ValueRef::small_int(42).unwrap());

        assert_eq!(lit1, lit2);
        assert_eq!(arena.stats().dedup_hits, 1);
    }

    #[test]
    fn test_ast_arena_binary_op() {
        let mut arena = AstArena::new();

        let left = arena.insert_literal(ValueRef::small_int(1).unwrap());
        let right = arena.insert_literal(ValueRef::small_int(2).unwrap());
        let add = arena.insert_binary_op("+", left, right);

        match arena.get(add) {
            Some(AstNodeData::BinaryOp {
                op_id,
                left_id,
                right_id,
            }) => {
                assert_eq!(arena.resolve_string(*op_id), "+");
                assert_eq!(*left_id, left);
                assert_eq!(*right_id, right);
            }
            _ => panic!("Expected binary op node"),
        }
    }

    #[test]
    fn test_ast_arena_function() {
        let mut arena = AstArena::new();

        let arg1 = arena.insert_literal(ValueRef::small_int(10).unwrap());
        let arg2 = arena.insert_literal(ValueRef::small_int(20).unwrap());
        let arg3 = arena.insert_literal(ValueRef::small_int(30).unwrap());

        let func = arena.insert_function("SUM", vec![arg1, arg2, arg3]);

        match arena.get(func) {
            Some(AstNodeData::Function {
                name_id,
                args_count,
                ..
            }) => {
                assert_eq!(arena.resolve_string(*name_id), "SUM");
                assert_eq!(*args_count, 3);
            }
            _ => panic!("Expected function node"),
        }

        let args = arena.get_function_args(func).unwrap();
        assert_eq!(args, &[arg1, arg2, arg3]);
    }

    #[test]
    fn test_ast_arena_structural_sharing() {
        let mut arena = AstArena::new();

        // A1 is shared by both expressions below.
        let a1_ref = arena.insert_reference(
            "A1",
            CompactRefType::Cell {
                sheet: None,
                row: 1,
                col: 1,
            },
        );

        let one = arena.insert_literal(ValueRef::small_int(1).unwrap());
        let expr1 = arena.insert_binary_op("+", a1_ref, one);

        let two = arena.insert_literal(ValueRef::small_int(2).unwrap());
        let expr2 = arena.insert_binary_op("*", a1_ref, two);

        assert_ne!(expr1, expr2);
        // a1_ref, one, expr1, two, expr2.
        assert_eq!(arena.stats().node_count, 5);

        let a1_ref2 = arena.insert_reference(
            "A1",
            CompactRefType::Cell {
                sheet: None,
                row: 1,
                col: 1,
            },
        );
        assert_eq!(a1_ref, a1_ref2);
        assert_eq!(arena.stats().node_count, 5);
        assert_eq!(arena.stats().dedup_hits, 1);
    }

    #[test]
    fn test_ast_arena_array() {
        let mut arena = AstArena::new();

        let elements = vec![
            arena.insert_literal(ValueRef::small_int(1).unwrap()),
            arena.insert_literal(ValueRef::small_int(2).unwrap()),
            arena.insert_literal(ValueRef::small_int(3).unwrap()),
            arena.insert_literal(ValueRef::small_int(4).unwrap()),
        ];

        let array = arena.insert_array(2, 2, elements.clone());

        match arena.get(array) {
            Some(AstNodeData::Array { rows, cols, .. }) => {
                assert_eq!(*rows, 2);
                assert_eq!(*cols, 2);
            }
            _ => panic!("Expected array node"),
        }

        let stored_elements = arena.get_array_elements(array).unwrap();
        assert_eq!(stored_elements, &elements[..]);

        let (rows, cols, info_elements) = arena.get_array_elements_info(array).unwrap();
        assert_eq!((rows, cols), (2, 2));
        assert_eq!(info_elements, &elements[..]);
    }

    #[test]
    #[should_panic(expected = "Array dimensions don't match element count")]
    fn test_ast_arena_array_dimension_mismatch_panics() {
        let mut arena = AstArena::new();
        let one = arena.insert_literal(ValueRef::small_int(1).unwrap());
        arena.insert_array(2, 2, vec![one]);
    }

    #[test]
    fn test_ast_arena_accessors_reject_wrong_kind() {
        let mut arena = AstArena::new();
        let lit = arena.insert_literal(ValueRef::small_int(7).unwrap());

        assert!(arena.get_function_args(lit).is_none());
        assert!(arena.get_array_elements(lit).is_none());
        assert!(arena.get_array_elements_info(lit).is_none());
        assert!(arena.get(AstNodeId(lit.as_u32() + 1)).is_none());
        assert_eq!(lit.to_string(), "AstNode(0)");
    }

    #[test]
    fn test_ast_arena_complex_expression() {
        let mut arena = AstArena::new();

        // SUM(A1:A10) + IF(B1 > 0, C1, D1)
        let range = arena.insert_reference(
            "A1:A10",
            CompactRefType::Range {
                sheet: None,
                start_row: 1,
                start_col: 1,
                end_row: 10,
                end_col: 1,
            },
        );

        let sum = arena.insert_function("SUM", vec![range]);

        let b1 = arena.insert_reference(
            "B1",
            CompactRefType::Cell {
                sheet: None,
                row: 1,
                col: 2,
            },
        );

        let zero = arena.insert_literal(ValueRef::small_int(0).unwrap());

        let condition = arena.insert_binary_op(">", b1, zero);

        let c1 = arena.insert_reference(
            "C1",
            CompactRefType::Cell {
                sheet: None,
                row: 1,
                col: 3,
            },
        );
        let d1 = arena.insert_reference(
            "D1",
            CompactRefType::Cell {
                sheet: None,
                row: 1,
                col: 4,
            },
        );

        let if_expr = arena.insert_function("IF", vec![condition, c1, d1]);

        let final_expr = arena.insert_binary_op("+", sum, if_expr);

        assert!(arena.get(final_expr).is_some());
        // range, sum, b1, zero, condition, c1, d1, if_expr, final_expr.
        assert_eq!(arena.stats().node_count, 9);
    }

    #[test]
    fn test_ast_arena_string_deduplication() {
        let mut arena = AstArena::new();

        let one = arena.insert_literal(ValueRef::small_int(1).unwrap());
        let two = arena.insert_literal(ValueRef::small_int(2).unwrap());
        let three = arena.insert_literal(ValueRef::small_int(3).unwrap());

        let add1 = arena.insert_binary_op("+", one, two);
        let add2 = arena.insert_binary_op("+", two, three);
        let add3 = arena.insert_binary_op("+", one, three);

        // Distinct operands give distinct nodes, but "+" is interned once.
        assert_ne!(add1, add2);
        assert_ne!(add2, add3);
        assert_ne!(add1, add3);
        assert_eq!(arena.strings().len(), 1);
    }

    #[test]
    fn test_ast_arena_clear() {
        let mut arena = AstArena::new();

        arena.insert_literal(ValueRef::small_int(1).unwrap());
        arena.insert_literal(ValueRef::small_int(2).unwrap());
        let left = arena.insert_literal(ValueRef::small_int(3).unwrap());
        let right = arena.insert_literal(ValueRef::small_int(4).unwrap());
        arena.insert_binary_op("+", left, right);

        assert_eq!(arena.stats().node_count, 5);

        arena.clear();

        assert_eq!(arena.stats().node_count, 0);
        assert_eq!(arena.stats().dedup_hits, 0);
        assert_eq!(arena.strings().len(), 0);
    }
}