1use super::string_interner::{StringId, StringInterner};
4use super::value_ref::ValueRef;
5use rustc_hash::FxHashMap;
6use std::collections::hash_map::DefaultHasher;
7use std::fmt;
8use std::hash::{Hash, Hasher};
9
/// Handle identifying a node stored in an [`AstArena`].
///
/// The wrapped `u32` is the node's position in the arena's node table.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct AstNodeId(u32);
13
14impl AstNodeId {
15 pub fn as_u32(self) -> u32 {
16 self.0
17 }
18}
19
20impl fmt::Display for AstNodeId {
21 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
22 write!(f, "AstNode({})", self.0)
23 }
24}
25
26#[derive(Debug, Clone, PartialEq, Eq, Hash)]
28pub enum AstNodeData {
29 Literal(ValueRef),
31
32 Reference {
34 original_id: StringId, ref_type: CompactRefType, },
37
38 UnaryOp { op_id: StringId, expr_id: AstNodeId },
40
41 BinaryOp {
43 op_id: StringId,
44 left_id: AstNodeId,
45 right_id: AstNodeId,
46 },
47
48 Function {
50 name_id: StringId,
51 args_offset: u32, args_count: u16, },
54
55 Array {
57 rows: u16,
58 cols: u16,
59 elements_offset: u32, },
61}
62
63#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
65pub enum SheetKey {
66 Id(u16),
67 Name(StringId),
68}
69
70#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
72pub enum CompactRefType {
73 Cell {
74 sheet: Option<SheetKey>,
75 row: u32,
76 col: u32,
77 },
78 Range {
79 sheet: Option<SheetKey>,
80 start_row: u32,
81 start_col: u32,
82 end_row: u32,
83 end_col: u32,
84 },
85 NamedRange(StringId),
86 Table(StringId),
87}
88
89pub struct AstArena {
91 nodes: Vec<AstNodeData>,
93
94 dedup_map: FxHashMap<u64, AstNodeId>,
96
97 function_args: Vec<AstNodeId>,
99
100 array_elements: Vec<AstNodeId>,
102
103 strings: StringInterner,
105
106 dedup_hits: usize,
108}
109
110impl AstArena {
111 pub fn new() -> Self {
112 Self {
113 nodes: Vec::new(),
114 dedup_map: FxHashMap::default(),
115 function_args: Vec::new(),
116 array_elements: Vec::new(),
117 strings: StringInterner::new(),
118 dedup_hits: 0,
119 }
120 }
121
122 pub fn with_capacity(node_cap: usize) -> Self {
123 Self {
124 nodes: Vec::with_capacity(node_cap),
125 dedup_map: FxHashMap::with_capacity_and_hasher(node_cap, Default::default()),
126 function_args: Vec::with_capacity(node_cap * 2), array_elements: Vec::with_capacity(node_cap),
128 strings: StringInterner::with_capacity(node_cap / 10),
129 dedup_hits: 0,
130 }
131 }
132
133 pub fn insert(&mut self, node: AstNodeData) -> AstNodeId {
135 let hash = self.hash_node(&node);
137
138 if let Some(&id) = self.dedup_map.get(&hash) {
140 if self.nodes[id.0 as usize] == node {
142 self.dedup_hits += 1;
143 return id;
144 }
145 }
146
147 let id = AstNodeId(self.nodes.len() as u32);
149 self.nodes.push(node);
150 self.dedup_map.insert(hash, id);
151 id
152 }
153
154 pub fn insert_literal(&mut self, value: ValueRef) -> AstNodeId {
156 self.insert(AstNodeData::Literal(value))
157 }
158
159 pub fn insert_reference(&mut self, original: &str, ref_type: CompactRefType) -> AstNodeId {
161 let original_id = self.strings.intern(original);
162 self.insert(AstNodeData::Reference {
163 original_id,
164 ref_type,
165 })
166 }
167
168 pub fn insert_unary_op(&mut self, op: &str, expr: AstNodeId) -> AstNodeId {
170 let op_id = self.strings.intern(op);
171 self.insert(AstNodeData::UnaryOp {
172 op_id,
173 expr_id: expr,
174 })
175 }
176
177 pub fn insert_binary_op(&mut self, op: &str, left: AstNodeId, right: AstNodeId) -> AstNodeId {
179 let op_id = self.strings.intern(op);
180 self.insert(AstNodeData::BinaryOp {
181 op_id,
182 left_id: left,
183 right_id: right,
184 })
185 }
186
187 pub fn insert_function(&mut self, name: &str, args: Vec<AstNodeId>) -> AstNodeId {
189 let name_id = self.strings.intern(name);
190 let args_offset = self.function_args.len() as u32;
191 let args_count = args.len() as u16;
192
193 self.function_args.extend(args);
194
195 self.insert(AstNodeData::Function {
196 name_id,
197 args_offset,
198 args_count,
199 })
200 }
201
202 pub fn insert_array(&mut self, rows: u16, cols: u16, elements: Vec<AstNodeId>) -> AstNodeId {
204 assert_eq!(
205 elements.len(),
206 (rows * cols) as usize,
207 "Array dimensions don't match element count"
208 );
209
210 let elements_offset = self.array_elements.len() as u32;
211 self.array_elements.extend(elements);
212
213 self.insert(AstNodeData::Array {
214 rows,
215 cols,
216 elements_offset,
217 })
218 }
219
220 pub fn get(&self, id: AstNodeId) -> Option<&AstNodeData> {
222 self.nodes.get(id.0 as usize)
223 }
224
225 pub fn get_function_args(&self, id: AstNodeId) -> Option<&[AstNodeId]> {
227 match self.get(id)? {
228 AstNodeData::Function {
229 args_offset,
230 args_count,
231 ..
232 } => {
233 let start = *args_offset as usize;
234 let end = start + *args_count as usize;
235 Some(&self.function_args[start..end])
236 }
237 _ => None,
238 }
239 }
240
241 pub fn get_array_elements(&self, id: AstNodeId) -> Option<&[AstNodeId]> {
243 match self.get(id)? {
244 AstNodeData::Array {
245 rows,
246 cols,
247 elements_offset,
248 } => {
249 let start = *elements_offset as usize;
250 let count = (*rows * *cols) as usize;
251 let end = start + count;
252 Some(&self.array_elements[start..end])
253 }
254 _ => None,
255 }
256 }
257
258 pub fn resolve_string(&self, id: StringId) -> &str {
260 self.strings.resolve(id)
261 }
262
263 pub fn strings(&self) -> &StringInterner {
265 &self.strings
266 }
267
268 pub fn strings_mut(&mut self) -> &mut StringInterner {
270 &mut self.strings
271 }
272
273 fn hash_node(&self, node: &AstNodeData) -> u64 {
275 let mut hasher = DefaultHasher::new();
276 node.hash(&mut hasher);
277 hasher.finish()
278 }
279
280 pub fn stats(&self) -> AstArenaStats {
282 AstArenaStats {
283 node_count: self.nodes.len(),
284 dedup_hits: self.dedup_hits,
285 string_count: self.strings.len(),
286 total_args: self.function_args.len(),
287 total_array_elements: self.array_elements.len(),
288 }
289 }
290
291 pub fn memory_usage(&self) -> usize {
293 self.nodes.capacity() * std::mem::size_of::<AstNodeData>()
294 + self.dedup_map.capacity() * (8 + 4) + self.function_args.capacity() * 4
296 + self.array_elements.capacity() * 4
297 + self.strings.memory_usage()
298 }
299
300 pub fn clear(&mut self) {
302 self.nodes.clear();
303 self.dedup_map.clear();
304 self.function_args.clear();
305 self.array_elements.clear();
306 self.strings.clear();
307 self.dedup_hits = 0;
308 }
309}
310
311impl Default for AstArena {
312 fn default() -> Self {
313 Self::new()
314 }
315}
316
317impl fmt::Debug for AstArena {
318 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
319 f.debug_struct("AstArena")
320 .field("nodes", &self.nodes.len())
321 .field("dedup_hits", &self.dedup_hits)
322 .field("strings", &self.strings.len())
323 .finish()
324 }
325}
326
/// Counters describing the contents of an `AstArena` at one point in time.
#[derive(Clone, Debug)]
pub struct AstArenaStats {
    /// Number of nodes stored.
    pub node_count: usize,
    /// Number of inserts that returned an already-stored node.
    pub dedup_hits: usize,
    /// Number of strings held by the interner.
    pub string_count: usize,
    /// Total entries in the function-argument table.
    pub total_args: usize,
    /// Total entries in the array-element table.
    pub total_array_elements: usize,
}
336
#[cfg(test)]
mod tests {
    use super::*;

    /// Inserts an unqualified single-cell reference with the given text and
    /// coordinates.
    fn cell_ref(arena: &mut AstArena, text: &str, row: u32, col: u32) -> AstNodeId {
        arena.insert_reference(
            text,
            CompactRefType::Cell {
                sheet: None,
                row,
                col,
            },
        )
    }

    /// Distinct literals get distinct ids and can be read back.
    #[test]
    fn test_ast_arena_literal() {
        let mut arena = AstArena::new();

        let forty_two = arena.insert_literal(ValueRef::small_int(42).unwrap());
        let truthy = arena.insert_literal(ValueRef::boolean(true));
        assert_ne!(forty_two, truthy);

        if let Some(AstNodeData::Literal(value)) = arena.get(forty_two) {
            assert_eq!(value.as_small_int(), Some(42));
        } else {
            panic!("Expected literal node");
        }
    }

    /// Inserting the same literal twice returns the same id and counts a hit.
    #[test]
    fn test_ast_arena_deduplication() {
        let mut arena = AstArena::new();

        let first = arena.insert_literal(ValueRef::small_int(42).unwrap());
        let second = arena.insert_literal(ValueRef::small_int(42).unwrap());

        assert_eq!(first, second);
        assert_eq!(arena.stats().dedup_hits, 1);
    }

    /// A binary-op node records its operator text and both operands.
    #[test]
    fn test_ast_arena_binary_op() {
        let mut arena = AstArena::new();

        let left = arena.insert_literal(ValueRef::small_int(1).unwrap());
        let right = arena.insert_literal(ValueRef::small_int(2).unwrap());
        let add = arena.insert_binary_op("+", left, right);

        if let Some(AstNodeData::BinaryOp {
            op_id,
            left_id,
            right_id,
        }) = arena.get(add)
        {
            assert_eq!(arena.resolve_string(*op_id), "+");
            assert_eq!(*left_id, left);
            assert_eq!(*right_id, right);
        } else {
            panic!("Expected binary op node");
        }
    }

    /// A function node records its name and argument count, and its
    /// arguments can be read back in order.
    #[test]
    fn test_ast_arena_function() {
        let mut arena = AstArena::new();

        let ten = arena.insert_literal(ValueRef::small_int(10).unwrap());
        let twenty = arena.insert_literal(ValueRef::small_int(20).unwrap());
        let thirty = arena.insert_literal(ValueRef::small_int(30).unwrap());

        let func = arena.insert_function("SUM", vec![ten, twenty, thirty]);

        if let Some(AstNodeData::Function {
            name_id,
            args_count,
            ..
        }) = arena.get(func)
        {
            assert_eq!(arena.resolve_string(*name_id), "SUM");
            assert_eq!(*args_count, 3);
        } else {
            panic!("Expected function node");
        }

        let args = arena.get_function_args(func).unwrap();
        assert_eq!(args, &[ten, twenty, thirty]);
    }

    /// A reference shared by two expressions is stored once, and re-inserting
    /// it yields the original id.
    #[test]
    fn test_ast_arena_structural_sharing() {
        let mut arena = AstArena::new();

        let a1 = cell_ref(&mut arena, "A1", 1, 1);

        // A1 + 1
        let one = arena.insert_literal(ValueRef::small_int(1).unwrap());
        arena.insert_binary_op("+", a1, one);

        // A1 * 2
        let two = arena.insert_literal(ValueRef::small_int(2).unwrap());
        arena.insert_binary_op("*", a1, two);

        // A1, 1, A1 + 1, 2, A1 * 2
        assert_eq!(arena.stats().node_count, 5);

        let a1_again = cell_ref(&mut arena, "A1", 1, 1);
        assert_eq!(a1, a1_again);
    }

    /// An array node records its dimensions and its elements can be read back.
    #[test]
    fn test_ast_arena_array() {
        let mut arena = AstArena::new();

        let elements = vec![
            arena.insert_literal(ValueRef::small_int(1).unwrap()),
            arena.insert_literal(ValueRef::small_int(2).unwrap()),
            arena.insert_literal(ValueRef::small_int(3).unwrap()),
            arena.insert_literal(ValueRef::small_int(4).unwrap()),
        ];

        let array = arena.insert_array(2, 2, elements.clone());

        if let Some(AstNodeData::Array { rows, cols, .. }) = arena.get(array) {
            assert_eq!(*rows, 2);
            assert_eq!(*cols, 2);
        } else {
            panic!("Expected array node");
        }

        let stored = arena.get_array_elements(array).unwrap();
        assert_eq!(stored, &elements[..]);
    }

    /// Builds `SUM(A1:A10) + IF(B1 > 0, C1, D1)` and checks the node count.
    #[test]
    fn test_ast_arena_complex_expression() {
        let mut arena = AstArena::new();

        // SUM(A1:A10)
        let range = arena.insert_reference(
            "A1:A10",
            CompactRefType::Range {
                sheet: None,
                start_row: 1,
                start_col: 1,
                end_row: 10,
                end_col: 1,
            },
        );
        let sum = arena.insert_function("SUM", vec![range]);

        // B1 > 0
        let b1 = cell_ref(&mut arena, "B1", 1, 2);
        let zero = arena.insert_literal(ValueRef::small_int(0).unwrap());
        let condition = arena.insert_binary_op(">", b1, zero);

        // IF(B1 > 0, C1, D1)
        let c1 = cell_ref(&mut arena, "C1", 1, 3);
        let d1 = cell_ref(&mut arena, "D1", 1, 4);
        let if_expr = arena.insert_function("IF", vec![condition, c1, d1]);

        let final_expr = arena.insert_binary_op("+", sum, if_expr);

        assert!(arena.get(final_expr).is_some());
        // range, SUM, B1, 0, B1 > 0, C1, D1, IF, +
        assert_eq!(arena.stats().node_count, 9);
    }

    /// The same operator text used by several nodes is interned once.
    #[test]
    fn test_ast_arena_string_deduplication() {
        let mut arena = AstArena::new();

        let one = arena.insert_literal(ValueRef::small_int(1).unwrap());
        let two = arena.insert_literal(ValueRef::small_int(2).unwrap());
        let three = arena.insert_literal(ValueRef::small_int(3).unwrap());

        arena.insert_binary_op("+", one, two);
        arena.insert_binary_op("+", two, three);
        arena.insert_binary_op("+", one, three);

        assert_eq!(arena.strings().len(), 1);
    }

    /// `clear` empties both the node table and the string interner.
    #[test]
    fn test_ast_arena_clear() {
        let mut arena = AstArena::new();

        arena.insert_literal(ValueRef::small_int(1).unwrap());
        arena.insert_literal(ValueRef::small_int(2).unwrap());
        let left = arena.insert_literal(ValueRef::small_int(3).unwrap());
        let right = arena.insert_literal(ValueRef::small_int(4).unwrap());
        arena.insert_binary_op("+", left, right);

        assert_eq!(arena.stats().node_count, 5);

        arena.clear();

        assert_eq!(arena.stats().node_count, 0);
        assert_eq!(arena.strings().len(), 0);
    }
}