Skip to main content

firmion_const_eval/
const_eval.rs

1// Const-time evaluation for firmion.
2//
3// The public interface is `evaluate_and_prune()`, which:
4//   1. Walks the immutable AST to evaluate constants, if-conditions, and asserts.
5//   2. Clones the AST and prunes all `if/else` nodes.
6//
7// The caller receives a fully resolved `SymbolTable` and a strictly immutable `Ast`
8// ready for the LayoutDb phase.
9
10// Don't clutter upstream docs.rs for an otherwise private library.
11#![doc(hidden)]
12
13pub mod ireval;
14pub mod linearizer;
15
16use anyhow::bail;
17use diags::depth_guard::{DepthGuard, MAX_RECURSION_DEPTH};
18use diags::{Diags, SourceSpan};
19use indextree::NodeId;
20use parse_int::parse;
21use std::collections::HashMap;
22
23#[allow(unused_imports)]
24use tracing::{debug, trace};
25
26use ast::{Ast, LexToken};
27use ast::astdb::AstDb;
28use ir::{ConstBuiltins, ObjProps, ParameterValue, RegionProps, strip_kmg};
29use ir::symtable::SymbolTable;
30
31// ── Internal error type for const arithmetic ─────────────────────────────────
32
/// Failure modes of the checked const arithmetic in `calc_u64_op` / `calc_i64_op`.
#[derive(Debug)]
enum CalcErr {
    /// The result does not fit the operand type (also used for an operator the
    /// arithmetic helpers do not know); carries the user-facing message.
    Overflow(String),
    /// The divisor of `/` or `%` was zero.
    DivByZero,
}
37
38// ── Expression Evaluator (Tree Walker) ─────────────────────────────────────────
39
40/// Evaluate an expression subtree natively on the AST.
41pub fn eval_expr_tree(
42    ast: &Ast,
43    nid: NodeId,
44    symbol_table: &mut SymbolTable,
45    diags: &mut Diags,
46) -> Option<ParameterValue> {
47    let _guard = DepthGuard::enter(MAX_RECURSION_DEPTH).or_else(|| {
48        diags.err1(
49            "ERR_115",
50            &format!(
51                "Const expression nesting depth exceeds maximum ({}).",
52                MAX_RECURSION_DEPTH
53            ),
54            ast.get_tinfo(nid).loc.clone(),
55        );
56        None
57    })?;
58
59    let tinfo = ast.get_tinfo(nid);
60    let src_loc = &tinfo.loc;
61
62    match tinfo.tok {
63        // --- Literals ---
64        LexToken::Integer => {
65            let (base, mult) = strip_kmg(tinfo.val);
66            let v: i64 = parse::<i64>(base)
67                .ok()
68                .and_then(|v| v.checked_mul(mult as i64))
69                .ok_or(())
70                .ok()
71                .or_else(|| {
72                    diags.err1(
73                        "ERR_88",
74                        &format!("Malformed integer in const expression: {}", tinfo.val),
75                        src_loc.clone(),
76                    );
77                    None
78                })?;
79            Some(ParameterValue::Integer(v))
80        }
81        LexToken::U64 => {
82            let no_u = tinfo.val.strip_suffix('u').unwrap_or(tinfo.val);
83            let (base, mult) = strip_kmg(no_u);
84            let v: u64 = parse::<u64>(base)
85                .ok()
86                .and_then(|v| v.checked_mul(mult))
87                .ok_or(())
88                .ok()
89                .or_else(|| {
90                    diags.err1(
91                        "ERR_89",
92                        &format!("Malformed U64 in const expression: {}", tinfo.val),
93                        src_loc.clone(),
94                    );
95                    None
96                })?;
97            Some(ParameterValue::U64(v))
98        }
99        LexToken::I64 => {
100            let no_i = tinfo.val.strip_suffix('i').unwrap_or(tinfo.val);
101            let (base, mult) = strip_kmg(no_i);
102            let v: i64 = parse::<i64>(base)
103                .ok()
104                .and_then(|v| v.checked_mul(mult as i64))
105                .ok_or(())
106                .ok()
107                .or_else(|| {
108                    diags.err1(
109                        "ERR_90",
110                        &format!("Malformed I64 in const expression: {}", tinfo.val),
111                        src_loc.clone(),
112                    );
113                    None
114                })?;
115            Some(ParameterValue::I64(v))
116        }
117        LexToken::QuotedString => {
118            let trimmed = tinfo
119                .val
120                .strip_prefix('"')
121                .unwrap_or(tinfo.val)
122                .strip_suffix('"')
123                .unwrap_or(tinfo.val)
124                .to_string();
125            Some(ParameterValue::QuotedString(trimmed))
126        }
127        LexToken::Identifier => {
128            let name = tinfo.val.to_string();
129            if let Some(val) = symbol_table.get_value(&name) {
130                symbol_table.mark_used(&name);
131                Some(val)
132            } else {
133                diags.err1("ERR_86", &format!("Unknown or uninitialized identifier '{}' in const expression. Constants must be defined before use.", name), src_loc.clone());
134                None
135            }
136        }
137
138        // --- Builtins ---
139        LexToken::BuiltinVersionString => Some(ParameterValue::QuotedString(
140            ConstBuiltins::get().firmion_version_string.to_string(),
141        )),
142        LexToken::BuiltinVersionMajor => Some(ParameterValue::U64(
143            ConstBuiltins::get().firmion_version_major,
144        )),
145        LexToken::BuiltinVersionMinor => Some(ParameterValue::U64(
146            ConstBuiltins::get().firmion_version_minor,
147        )),
148        LexToken::BuiltinVersionPatch => Some(ParameterValue::U64(
149            ConstBuiltins::get().firmion_version_patch,
150        )),
151
152        // --- Unary ---
153        LexToken::ToI64 | LexToken::ToU64 => {
154            let child_nid = ast.children(nid).next().unwrap();
155            let val = eval_expr_tree(ast, child_nid, symbol_table, diags)?;
156            match (&val, tinfo.tok) {
157                (ParameterValue::U64(v), LexToken::ToI64) => Some(ParameterValue::I64(*v as i64)),
158                (ParameterValue::I64(_) | ParameterValue::Integer(_), LexToken::ToI64) => {
159                    Some(ParameterValue::I64(val.to_i64()))
160                }
161                (
162                    ParameterValue::U64(_) | ParameterValue::I64(_) | ParameterValue::Integer(_),
163                    LexToken::ToU64,
164                ) => Some(ParameterValue::U64(val.to_u64())),
165                _ => {
166                    diags.err1(
167                        "ERR_87",
168                        &format!(
169                            "Cannot apply '{:?}' to {:?} in a const expression.",
170                            tinfo.tok,
171                            val.data_type()
172                        ),
173                        src_loc.clone(),
174                    );
175                    None
176                }
177            }
178        }
179
180        // --- Binary ---
181        LexToken::Plus
182        | LexToken::Minus
183        | LexToken::Asterisk
184        | LexToken::FSlash
185        | LexToken::Percent
186        | LexToken::Ampersand
187        | LexToken::Pipe
188        | LexToken::DoubleLess
189        | LexToken::DoubleGreater => {
190            let mut it = ast.children(nid);
191            let lhs_val = eval_expr_tree(ast, it.next().unwrap(), symbol_table, diags)?;
192            let rhs_val = eval_expr_tree(ast, it.next().unwrap(), symbol_table, diags)?;
193            apply_binary_op(tinfo.tok, lhs_val, rhs_val, src_loc, diags)
194        }
195
196        // --- Comparisons ---
197        LexToken::DoubleEq
198        | LexToken::NEq
199        | LexToken::GEq
200        | LexToken::LEq
201        | LexToken::Gt
202        | LexToken::Lt => {
203            let mut it = ast.children(nid);
204            let lhs_val = eval_expr_tree(ast, it.next().unwrap(), symbol_table, diags)?;
205            let rhs_val = eval_expr_tree(ast, it.next().unwrap(), symbol_table, diags)?;
206            apply_comparison_op(tinfo.tok, lhs_val, rhs_val, src_loc, diags)
207        }
208
209        // --- Logical ---
210        LexToken::DoubleAmpersand | LexToken::DoublePipe => {
211            let mut it = ast.children(nid);
212            let lhs_val = eval_expr_tree(ast, it.next().unwrap(), symbol_table, diags)?;
213            let rhs_val = eval_expr_tree(ast, it.next().unwrap(), symbol_table, diags)?;
214            let Some(lhs_b) = lhs_val.to_bool() else {
215                diags.err1(
216                    "ERR_114",
217                    "'&&'/'||' operands must be numeric",
218                    src_loc.clone(),
219                );
220                return None;
221            };
222            let Some(rhs_b) = rhs_val.to_bool() else {
223                diags.err1(
224                    "ERR_124",
225                    "'&&'/'||' operands must be numeric",
226                    src_loc.clone(),
227                );
228                return None;
229            };
230            let result = if tinfo.tok == LexToken::DoubleAmpersand {
231                lhs_b && rhs_b
232            } else {
233                lhs_b || rhs_b
234            };
235            Some(ParameterValue::U64(if result { 1 } else { 0 }))
236        }
237
238        // --- Rejected layout ops ---
239        LexToken::Sizeof
240        | LexToken::BuiltinOutputSize
241        | LexToken::BuiltinOutputAddr
242        | LexToken::Addr
243        | LexToken::AddrOffset
244        | LexToken::SecOffset
245        | LexToken::FileOffset
246        | LexToken::ObjAlign
247        | LexToken::ObjLma
248        | LexToken::ObjVma => {
249            diags.err1("ERR_85", &format!("Operation '{:?}' cannot be used in a const expression because it requires engine-time layout or addressing.", tinfo.tok), src_loc.clone());
250            None
251        }
252
253        _ => {
254            diags.err1(
255                "ERR_116",
256                &format!(
257                    "Operation '{:?}' is not supported in a const expression.",
258                    tinfo.tok
259                ),
260                src_loc.clone(),
261            );
262            None
263        }
264    }
265}
266
267// ── Numeric helpers ────────────────────────────────────────────────────────────
268
269fn coerce_numeric_pair(
270    lhs: ParameterValue,
271    rhs: ParameterValue,
272    err_code: &str,
273    src_loc: &SourceSpan,
274    diags: &mut Diags,
275) -> Option<(ParameterValue, ParameterValue)> {
276    use ParameterValue::*;
277    match (&lhs, &rhs) {
278        (U64(_), U64(_))
279        | (I64(_), I64(_))
280        | (Integer(_), Integer(_))
281        | (QuotedString(_), QuotedString(_)) => Some((lhs, rhs)),
282        (U64(_), Integer(v)) => Some((lhs, U64(*v as u64))),
283        (Integer(v), U64(_)) => Some((U64(*v as u64), rhs)),
284        (I64(_), Integer(v)) => Some((lhs, I64(*v))),
285        (Integer(v), I64(_)) => Some((I64(*v), rhs)),
286        _ => {
287            diags.err1(
288                err_code,
289                &format!(
290                    "Type mismatch in const expression: {:?} and {:?}.",
291                    lhs.data_type(),
292                    rhs.data_type()
293                ),
294                src_loc.clone(),
295            );
296            None
297        }
298    }
299}
300
301fn apply_binary_op(
302    tok: LexToken,
303    lhs: ParameterValue,
304    rhs: ParameterValue,
305    src_loc: &SourceSpan,
306    diags: &mut Diags,
307) -> Option<ParameterValue> {
308    use ParameterValue::*;
309    let (lhs, rhs) = coerce_numeric_pair(lhs, rhs, "ERR_91", src_loc, diags)?;
310
311    let emit = |err: CalcErr, diags: &mut Diags| -> Option<ParameterValue> {
312        match err {
313            CalcErr::Overflow(msg) => diags.err1("ERR_93", &msg, src_loc.clone()),
314            CalcErr::DivByZero => diags.err1(
315                "ERR_94",
316                "Division by zero in const expression",
317                src_loc.clone(),
318            ),
319        }
320        None
321    };
322
323    match lhs {
324        U64(a) => {
325            let b = rhs.to_u64();
326            match calc_u64_op(tok, a, b) {
327                Ok(r) => Some(U64(r)),
328                Err(e) => emit(e, diags),
329            }
330        }
331        I64(a) => {
332            let b = rhs.to_i64();
333            match calc_i64_op(tok, a, b) {
334                Ok(r) => Some(I64(r)),
335                Err(e) => emit(e, diags),
336            }
337        }
338        Integer(a) => {
339            let b = rhs.to_i64();
340            match calc_i64_op(tok, a, b) {
341                Ok(r) => Some(Integer(r)),
342                Err(e) => emit(e, diags),
343            }
344        }
345        _ => {
346            diags.err1(
347                "ERR_92",
348                &format!(
349                    "Non-numeric type {:?} in arithmetic const expression.",
350                    lhs.data_type()
351                ),
352                src_loc.clone(),
353            );
354            None
355        }
356    }
357}
358
359fn apply_comparison_op(
360    tok: LexToken,
361    lhs: ParameterValue,
362    rhs: ParameterValue,
363    src_loc: &SourceSpan,
364    diags: &mut Diags,
365) -> Option<ParameterValue> {
366    use ParameterValue::*;
367    if let (QuotedString(a), QuotedString(b)) = (&lhs, &rhs) {
368        let result = match tok {
369            LexToken::DoubleEq => a == b,
370            LexToken::NEq => a != b,
371            _ => {
372                diags.err1(
373                    "ERR_96",
374                    "Ordered comparison (>=, <=) is not supported for strings.",
375                    src_loc.clone(),
376                );
377                return None;
378            }
379        };
380        return Some(U64(if result { 1 } else { 0 }));
381    }
382    let (lhs, rhs) = coerce_numeric_pair(lhs, rhs, "ERR_95", src_loc, diags)?;
383
384    let result = match lhs {
385        U64(a) => {
386            let b = rhs.to_u64();
387            match tok {
388                LexToken::DoubleEq => a == b,
389                LexToken::NEq => a != b,
390                LexToken::GEq => a >= b,
391                LexToken::LEq => a <= b,
392                LexToken::Gt => a > b,
393                LexToken::Lt => a < b,
394                _ => unreachable!(),
395            }
396        }
397        I64(a) | Integer(a) => {
398            let b = rhs.to_i64();
399            match tok {
400                LexToken::DoubleEq => a == b,
401                LexToken::NEq => a != b,
402                LexToken::GEq => a >= b,
403                LexToken::LEq => a <= b,
404                LexToken::Gt => a > b,
405                LexToken::Lt => a < b,
406                _ => unreachable!(),
407            }
408        }
409        _ => unreachable!(),
410    };
411    Some(U64(if result { 1 } else { 0 }))
412}
413
414fn calc_u64_op(tok: LexToken, a: u64, b: u64) -> Result<u64, CalcErr> {
415    match tok {
416        LexToken::Plus => a.checked_add(b).ok_or_else(|| {
417            CalcErr::Overflow(format!("Add expression '{a} + {b}' will overflow type U64"))
418        }),
419        LexToken::Minus => a.checked_sub(b).ok_or_else(|| {
420            CalcErr::Overflow(format!(
421                "Subtract expression '{a} - {b}' will underflow type U64"
422            ))
423        }),
424        LexToken::Asterisk => a.checked_mul(b).ok_or_else(|| {
425            CalcErr::Overflow(format!(
426                "Multiply expression '{a} * {b}' will overflow type U64"
427            ))
428        }),
429        LexToken::FSlash => a.checked_div(b).ok_or(CalcErr::DivByZero),
430        LexToken::Percent => {
431            if b == 0 {
432                Err(CalcErr::DivByZero)
433            } else {
434                Ok(a % b)
435            }
436        }
437        LexToken::Ampersand => Ok(a & b),
438        LexToken::Pipe => Ok(a | b),
439        LexToken::DoubleLess => Ok(a << (b & 63)),
440        LexToken::DoubleGreater => Ok(a >> (b & 63)),
441        _ => Err(CalcErr::Overflow(
442            "Unknown operator in U64 const expression".to_string(),
443        )),
444    }
445}
446
447fn calc_i64_op(tok: LexToken, a: i64, b: i64) -> Result<i64, CalcErr> {
448    match tok {
449        LexToken::Plus => a.checked_add(b).ok_or_else(|| {
450            CalcErr::Overflow(format!("Add expression '{a} + {b}' will overflow type I64"))
451        }),
452        LexToken::Minus => a.checked_sub(b).ok_or_else(|| {
453            CalcErr::Overflow(format!(
454                "Subtract expression '{a} - {b}' will underflow type I64"
455            ))
456        }),
457        LexToken::Asterisk => a.checked_mul(b).ok_or_else(|| {
458            CalcErr::Overflow(format!(
459                "Multiply expression '{a} * {b}' will overflow type I64"
460            ))
461        }),
462        LexToken::FSlash => {
463            if b == 0 {
464                Err(CalcErr::DivByZero)
465            } else {
466                Ok(a / b)
467            }
468        }
469        LexToken::Percent => {
470            if b == 0 {
471                Err(CalcErr::DivByZero)
472            } else {
473                Ok(a % b)
474            }
475        }
476        LexToken::Ampersand => Ok(a & b),
477        LexToken::Pipe => Ok(a | b),
478        LexToken::DoubleLess => Ok(a << (b & 63)),
479        LexToken::DoubleGreater => Ok(a >> (b & 63)),
480        _ => Err(CalcErr::Overflow(
481            "Unknown operator in I64 const expression".to_string(),
482        )),
483    }
484}
485
486// ── Region Evaluator ──────────────────────────────────────────────────────────
487
488pub fn evaluate_regions(
489    diags: &mut Diags,
490    ast: &Ast,
491    ast_db: &AstDb,
492    symbol_table: &mut SymbolTable,
493) -> Option<HashMap<String, RegionProps>> {
494    let mut bindings: HashMap<String, RegionProps> = HashMap::new();
495    let mut ok = true;
496
497    for (name, region) in &ast_db.regions {
498        let mut binding = RegionProps {
499            addr: 0,
500            size: 0,
501            name: name.clone(),
502            default_pad_byte: 0xFF,
503            src_loc: region.src_loc.clone(),
504        };
505        for prop_nid in ast.children(region.nid) {
506            let tinfo = ast.get_tinfo(prop_nid);
507            if tinfo.tok != LexToken::RegionProp {
508                continue;
509            }
510            let prop_name = tinfo.val.to_string();
511            let expr_nid = ast.children(prop_nid).next().unwrap();
512            let expr_loc = ast.get_tinfo(expr_nid).loc.clone();
513
514            match eval_expr_tree(ast, expr_nid, symbol_table, diags) {
515                None => {
516                    ok = false;
517                }
518                Some(val) => {
519                    if !val.is_numeric() {
520                        diags.err1(
521                            "ERR_180",
522                            &format!(
523                                "Region property '{}' must evaluate to a numeric value.",
524                                prop_name
525                            ),
526                            expr_loc,
527                        );
528                        ok = false;
529                        continue;
530                    }
531                    match prop_name.as_str() {
532                        "addr" => binding.addr = val.to_u64(),
533                        "size" => binding.size = val.to_u64(),
534                        "default_pad_byte" => {
535                            let pad_val = val.to_u64();
536                            if pad_val > 255 {
537                                diags.err1(
538                                    "ERR_181",
539                                    "Region property 'default_pad_byte' must be in range 0 to 255.",
540                                    expr_loc,
541                                );
542                                ok = false;
543                                continue;
544                            }
545                            binding.default_pad_byte = pad_val as u8;
546                        }
547                        _ => unreachable!(),
548                    }
549                }
550            }
551        }
552        bindings.insert(name.clone(), binding);
553    }
554
555    for (name, binding) in &bindings {
556        if binding.size > 0 && binding.addr.checked_add(binding.size).is_none() {
557            diags.err1(
558                "ERR_188",
559                &format!(
560                    "Region '{}' addr {:#X} + size {:#X} overflows u64.",
561                    name, binding.addr, binding.size
562                ),
563                binding.src_loc.clone(),
564            );
565            ok = false;
566        }
567    }
568    if ok { Some(bindings) } else { None }
569}
570
571/// Resolves obj block property values against the symbol table.
572/// Returns a map of obj name -> ObjProps, or None on error.
573pub fn evaluate_obj_props(
574    diags: &mut Diags,
575    ast: &Ast,
576    ast_db: &AstDb,
577    symbol_table: &mut SymbolTable,
578) -> Option<HashMap<String, ObjProps>> {
579    let mut props: HashMap<String, ObjProps> = HashMap::new();
580    let mut ok = true;
581
582    for (obj_name, obj_decl) in &ast_db.obj_decls {
583        let file    = resolve_obj_prop(obj_decl.nid, ast, "file",    obj_name, &obj_decl.src_loc, symbol_table, diags);
584        let section = resolve_obj_prop(obj_decl.nid, ast, "section", obj_name, &obj_decl.src_loc, symbol_table, diags);
585        // Optional properties: None signals resolution failure; Some(s) includes the empty string
586        // for absent properties.
587        let file_exclude    = resolve_optional_obj_prop(obj_decl.nid, ast, "file_exclude",    obj_name, &obj_decl.src_loc, symbol_table, diags);
588        let section_exclude = resolve_optional_obj_prop(obj_decl.nid, ast, "section_exclude", obj_name, &obj_decl.src_loc, symbol_table, diags);
589        match (file, section, file_exclude, section_exclude) {
590            (Some(f), Some(s), Some(fe), Some(se)) => {
591                props.insert(obj_name.clone(), ObjProps {
592                    file: f,
593                    name: s,
594                    file_exclude: fe,
595                    section_exclude: se,
596                    src_loc: obj_decl.src_loc.clone(),
597                });
598            }
599            _ => ok = false,
600        }
601    }
602    if ok { Some(props) } else { None }
603}
604
605fn resolve_obj_prop(
606    obj_nid: NodeId,
607    ast: &Ast,
608    prop: &str,
609    obj_name: &str,
610    src_loc: &diags::SourceSpan,
611    symbol_table: &mut SymbolTable,
612    diags: &mut Diags,
613) -> Option<String> {
614    let val_nid = ast
615        .children(obj_nid)
616        .filter_map(|prop_nid| {
617            let tinfo = ast.get_tinfo(prop_nid);
618            if tinfo.tok == LexToken::ObjProp && tinfo.val == prop {
619                ast.children(prop_nid).next()
620            } else {
621                None
622            }
623        })
624        .next()
625        .unwrap(); // parser guarantees both section and file are present
626    resolve_string_value(ast.get_tinfo(val_nid), obj_name, prop, src_loc, symbol_table, diags)
627}
628
629/// Resolves an optional obj property.  Returns `Some("")` if the property is
630/// absent, `Some(s)` if present and resolved, `None` if present but resolution
631/// failed (error already emitted).
632fn resolve_optional_obj_prop(
633    obj_nid: NodeId,
634    ast: &Ast,
635    prop: &str,
636    obj_name: &str,
637    src_loc: &diags::SourceSpan,
638    symbol_table: &mut SymbolTable,
639    diags: &mut Diags,
640) -> Option<String> {
641    let Some(val_nid) = ast
642        .children(obj_nid)
643        .filter_map(|prop_nid| {
644            let tinfo = ast.get_tinfo(prop_nid);
645            if tinfo.tok == LexToken::ObjProp && tinfo.val == prop {
646                ast.children(prop_nid).next()
647            } else {
648                None
649            }
650        })
651        .next()
652    else {
653        return Some(String::new()); // property absent -- no exclusion
654    };
655    resolve_string_value(ast.get_tinfo(val_nid), obj_name, prop, src_loc, symbol_table, diags)
656}
657
658fn resolve_string_value(
659    val_tinfo: &ast::TokenInfo<'_>,
660    obj_name: &str,
661    prop: &str,
662    src_loc: &diags::SourceSpan,
663    symbol_table: &mut SymbolTable,
664    diags: &mut Diags,
665) -> Option<String> {
666    match val_tinfo.tok {
667        LexToken::QuotedString => Some(
668            val_tinfo.val
669                .strip_prefix('"')
670                .unwrap_or("")
671                .strip_suffix('"')
672                .unwrap_or("")
673                .to_string(),
674        ),
675        LexToken::Identifier => {
676            let name = val_tinfo.val;
677            match symbol_table.get_value(name) {
678                Some(ParameterValue::QuotedString(s)) => {
679                    symbol_table.mark_used(name);
680                    Some(s)
681                }
682                Some(_) => {
683                    let m = format!(
684                        "obj '{}': '{}' const '{}' must be a string value",
685                        obj_name, prop, name
686                    );
687                    diags.err1("ERR_228", &m, src_loc.clone());
688                    None
689                }
690                None => {
691                    let m = format!(
692                        "obj '{}': '{}' const '{}' is not defined",
693                        obj_name, prop, name
694                    );
695                    diags.err1("ERR_229", &m, src_loc.clone());
696                    None
697                }
698            }
699        }
700        _ => unreachable!("obj property value must be QuotedString or Identifier"),
701    }
702}
703
704// ── Top-Level AST Walker & Pruner ─────────────────────────────────────────────
705
706fn get_if_branches(ast: &Ast, if_nid: NodeId) -> (Vec<NodeId>, Vec<NodeId>) {
707    let children: Vec<NodeId> = ast.children(if_nid).collect();
708    let then_close_idx = children[2..]
709        .iter()
710        .position(|&n| ast.get_tinfo(n).tok == LexToken::CloseBrace)
711        .map(|i| i + 2)
712        .unwrap();
713    let then_stmts = children[2..then_close_idx].to_vec();
714
715    let else_stmts = if then_close_idx + 1 < children.len() {
716        let after_else_idx = then_close_idx + 2;
717        if after_else_idx >= children.len() {
718            vec![]
719        } else {
720            let after_else_nid = children[after_else_idx];
721            if ast.get_tinfo(after_else_nid).tok == LexToken::If {
722                vec![after_else_nid]
723            } else {
724                let else_close_idx = children[after_else_idx + 1..]
725                    .iter()
726                    .position(|&n| ast.get_tinfo(n).tok == LexToken::CloseBrace)
727                    .map(|i| i + after_else_idx + 1)
728                    .unwrap();
729                children[after_else_idx + 1..else_close_idx].to_vec()
730            }
731        }
732    } else {
733        vec![]
734    };
735    (then_stmts, else_stmts)
736}
737
738fn walk_if_statement(
739    ast: &Ast,
740    if_nid: NodeId,
741    symbol_table: &mut SymbolTable,
742    diags: &mut Diags,
743) -> bool {
744    let mut it = ast.children(if_nid);
745    let cond_nid = it.next().unwrap();
746    let cond_loc = ast.get_tinfo(cond_nid).loc.clone();
747
748    let cond_val = eval_expr_tree(ast, cond_nid, symbol_table, diags);
749    let b = match cond_val.and_then(|v| v.to_bool()) {
750        Some(v) => v,
751        None => {
752            diags.err1(
753                "ERR_112",
754                "if condition must evaluate to a numeric type",
755                cond_loc,
756            );
757            return false;
758        }
759    };
760
761    let (then_stmts, else_stmts) = get_if_branches(ast, if_nid);
762    let taken = if b { then_stmts } else { else_stmts };
763
764    let mut ok = true;
765    for stmt_nid in taken {
766        ok &= evaluate_stmt(ast, stmt_nid, symbol_table, diags);
767    }
768    ok
769}
770
771fn evaluate_stmt(
772    ast: &Ast,
773    nid: NodeId,
774    symbol_table: &mut SymbolTable,
775    diags: &mut Diags,
776) -> bool {
777    let tinfo = ast.get_tinfo(nid);
778    match tinfo.tok {
779        LexToken::Const => {
780            let mut it = ast.children(nid);
781            let name_nid = it.next().unwrap();
782            let name = ast.get_tinfo(name_nid).val.to_string();
783            let second_nid = it.next().unwrap();
784            if ast.get_tinfo(second_nid).tok == LexToken::Eq {
785                let expr_nid = it.next().unwrap();
786                if let Some(val) = eval_expr_tree(ast, expr_nid, symbol_table, diags) {
787                    if !symbol_table.contains_key(&name) {
788                        symbol_table.define(name, val, Some(tinfo.loc.clone()));
789                    }
790                    true
791                } else {
792                    false
793                }
794            } else {
795                symbol_table.declare(name, tinfo.loc.clone());
796                true
797            }
798        }
799        LexToken::Eq => {
800            let mut it = ast.children(nid);
801            let name_nid = it.next().unwrap();
802            let expr_nid = it.next().unwrap();
803            let name = ast.get_tinfo(name_nid).val.to_string();
804            if let Some(val) = eval_expr_tree(ast, expr_nid, symbol_table, diags) {
805                symbol_table.assign(&name, val, &tinfo.loc, diags)
806            } else {
807                false
808            }
809        }
810        LexToken::Print if !diags.noprint => {
811            let mut s = String::new();
812            let mut ok = true;
813            for expr_nid in ast.children(nid) {
814                match eval_expr_tree(ast, expr_nid, symbol_table, diags) {
815                    Some(ParameterValue::QuotedString(v)) => s.push_str(&v),
816                    Some(ParameterValue::U64(v)) => s.push_str(&format!("{:#X}", v)),
817                    Some(ParameterValue::I64(v) | ParameterValue::Integer(v)) => {
818                        s.push_str(&format!("{}", v))
819                    }
820                    Some(_) => {
821                        diags.err1(
822                            "ERR_97",
823                            "Cannot print this value type in a const context",
824                            tinfo.loc.clone(),
825                        );
826                        ok = false;
827                    }
828                    None => {
829                        ok = false;
830                    }
831                }
832            }
833            if ok {
834                print!("{}", s);
835            }
836            ok
837        }
838        LexToken::Assert => {
839            let expr_nid = ast.children(nid).next().unwrap();
840            match eval_expr_tree(ast, expr_nid, symbol_table, diags).and_then(|v| v.to_bool()) {
841                Some(false) => {
842                    diags.err1(
843                        "ERR_98",
844                        "Assert expression failed in if/else body",
845                        tinfo.loc.clone(),
846                    );
847                    false
848                }
849                None => {
850                    diags.err1(
851                        "ERR_113",
852                        "assert condition must evaluate to a numeric type",
853                        tinfo.loc.clone(),
854                    );
855                    false
856                }
857                Some(true) => true,
858            }
859        }
860        LexToken::If => walk_if_statement(ast, nid, symbol_table, diags),
861
862        // All other statements are ignored by the const evaluator.
863        // These include section, region, output and more.
864        _ => true,
865    }
866}
867
868// ── Pruning (AST Rewriter) ────────────────────────────────────────────────────
869
870fn prune_body(
871    pruned: &mut Ast,
872    parent_nid: NodeId,
873    symbol_table: &mut SymbolTable,
874    diags: &mut Diags,
875    keep: fn(LexToken) -> bool,
876) -> anyhow::Result<()> {
877    loop {
878        let children: Vec<_> = pruned.children(parent_nid).collect();
879        let maybe_if = children
880            .iter()
881            .find(|&&nid| pruned.get_tinfo(nid).tok == LexToken::If)
882            .copied();
883        match maybe_if {
884            None => break,
885            Some(if_nid) => prune_if_node(pruned, if_nid, symbol_table, diags, keep)?,
886        }
887    }
888    Ok(())
889}
890
891fn prune_if_node(
892    pruned: &mut Ast,
893    if_nid: NodeId,
894    symbol_table: &mut SymbolTable,
895    diags: &mut Diags,
896    keep: fn(LexToken) -> bool,
897) -> anyhow::Result<()> {
898    let cond_nid = pruned.children(if_nid).next().unwrap();
899    let cond_val = eval_expr_tree(pruned, cond_nid, symbol_table, diags)
900        .and_then(|v| v.to_bool())
901        .unwrap_or(false);
902
903    let (then_stmts, else_stmts) = get_if_branches(pruned, if_nid);
904    let stmts_to_promote = if cond_val { &then_stmts } else { &else_stmts };
905
906    for &stmt_nid in stmts_to_promote {
907        if keep(pruned.get_tinfo(stmt_nid).tok) {
908            stmt_nid.detach(pruned.arena_mut());
909            if_nid.insert_before(stmt_nid, pruned.arena_mut());
910        }
911    }
912    if_nid.detach(pruned.arena_mut());
913    Ok(())
914}
915
916// ── Public Interface ──────────────────────────────────────────────────────────
917
918pub fn evaluate_and_prune<'a>(
919    diags: &mut Diags,
920    ast: &Ast<'a>,
921    ast_db: &AstDb,
922    defines: &HashMap<String, ParameterValue>,
923) -> anyhow::Result<(SymbolTable, Ast<'a>)> {
924    debug!("const_eval::evaluate_and_prune: ENTER");
925    let mut symbol_table = SymbolTable::new();
926    for (k, v) in defines {
927        symbol_table.define(k.clone(), v.clone(), None);
928    }
929
930    let mut ok = true;
931    for &nid in &ast_db.const_statements {
932        ok &= evaluate_stmt(ast, nid, &mut symbol_table, diags);
933    }
934    if !ok {
935        bail!("const_eval lowering failed.");
936    }
937
938    let mut pruned = ast.clone();
939    let root_nid = pruned.root();
940    prune_body(&mut pruned, root_nid, &mut symbol_table, diags, |tok| {
941        matches!(tok, LexToken::Section | LexToken::If)
942    })?;
943
944    let section_nids: Vec<_> = pruned
945        .children(pruned.root())
946        .filter(|&nid| pruned.get_tinfo(nid).tok == LexToken::Section)
947        .collect();
948    for sec_nid in section_nids {
949        prune_body(&mut pruned, sec_nid, &mut symbol_table, diags, |_| true)?;
950    }
951
952    debug!("const_eval::evaluate_and_prune: EXIT");
953    Ok((symbol_table, pruned))
954}