Skip to main content

asm_rs/
preprocessor.rs

1//! Preprocessor for assembly source text.
2//!
3//! Handles macro definitions (`.macro`/`.endm`), repeat loops (`.rept`, `.irp`,
4//! `.irpc`), and conditional assembly (`.if`/`.ifdef`/`.ifndef`/`.else`/`.endif`)
5//! before the source reaches the parser.
6//!
7//! The preprocessor operates on raw text, expanding directives in-place so that
8//! the downstream lexer and parser see only ordinary assembly statements.
9
10use alloc::borrow::Cow;
11use alloc::collections::BTreeMap;
12use alloc::format;
13use alloc::string::String;
14use alloc::vec::Vec;
15
16use crate::error::{AsmError, Span};
17
18/// Default maximum macro expansion recursion depth.
19const DEFAULT_MAX_RECURSION_DEPTH: usize = 256;
20
21/// Maximum total iterations across all `.rept`/`.irp`/`.irpc` blocks.
22const DEFAULT_MAX_ITERATION_COUNT: usize = 100_000;
23
/// Append `body` to `out`, replacing every occurrence of `placeholder` with `value`.
///
/// Writes straight into `out`, so no intermediate `String` is allocated.
/// Matches are found left to right and do not overlap; text inserted from
/// `value` is never rescanned.
///
/// An empty `placeholder` matches nothing: `body` is appended unchanged.
/// (Searching for `""` always succeeds at offset 0, which would otherwise
/// never advance through `body`.)
fn replace_single_param(out: &mut String, body: &str, placeholder: &str, value: &str) {
    if placeholder.is_empty() {
        out.push_str(body);
        return;
    }
    let mut rest = body;
    while let Some(pos) = rest.find(placeholder) {
        out.push_str(&rest[..pos]);
        out.push_str(value);
        rest = &rest[pos + placeholder.len()..];
    }
    // Whatever follows the last match (or all of `body` when nothing matched).
    out.push_str(rest);
}
42
/// A macro definition recorded from a `.macro` ... `.endm` block.
#[derive(Debug, Clone)]
struct MacroDef {
    /// Declared parameters in declaration order; invocation arguments are
    /// matched to them by position. Names are stored without the leading `\`.
    params: Vec<MacroParam>,
    /// The raw body text: the lines between `.macro` and the matching `.endm`,
    /// joined with `\n`. Parameter references are still unexpanded.
    body: String,
}
51
/// A macro parameter with optional default value.
#[derive(Debug, Clone)]
struct MacroParam {
    /// Name the macro body refers to as `\name` (stored without the `\`).
    name: String,
    /// Text substituted when the invocation supplies no argument at this
    /// position. With no default, the empty string is substituted.
    default: Option<String>,
    /// When `true`, the parameter receives every argument from its position
    /// onward, joined with `", "`.
    is_vararg: bool,
}
59
/// Preprocessor state.
///
/// Macro and symbol tables persist across `process()` calls; only the
/// iteration counter is reset at the start of each call.
#[derive(Debug)]
pub struct Preprocessor {
    /// Defined macros: name → definition. A later `.macro` with the same name
    /// replaces the earlier entry.
    macros: BTreeMap<String, MacroDef>,
    /// Defined symbols for `.ifdef`/`.ifndef` and expression evaluation
    /// (name → value).
    symbols: BTreeMap<String, i128>,
    /// Counter for `\@` unique label generation; bumped once per macro
    /// substitution.
    expansion_counter: usize,
    /// Number of `expand_text` calls currently active (macro bodies,
    /// conditional bodies and loop bodies all count).
    recursion_depth: usize,
    /// Maximum recursion depth (configurable).
    max_recursion_depth: usize,
    /// Maximum total iteration count (configurable).
    max_iteration_count: usize,
    /// Total iteration count across all loops (bounds check).
    iteration_count: usize,
}
78
79impl Preprocessor {
80    /// Create a new preprocessor.
81    pub fn new() -> Self {
82        Self {
83            macros: BTreeMap::new(),
84            symbols: BTreeMap::new(),
85            expansion_counter: 0,
86            recursion_depth: 0,
87            max_recursion_depth: DEFAULT_MAX_RECURSION_DEPTH,
88            max_iteration_count: DEFAULT_MAX_ITERATION_COUNT,
89            iteration_count: 0,
90        }
91    }
92
    /// Set the maximum recursion depth for macro expansion.
    ///
    /// Every nested expansion level counts toward this limit: macro bodies,
    /// selected conditional bodies, and loop bodies. Expansion fails with
    /// `AsmError::ResourceLimitExceeded` once the nesting would exceed `depth`.
    pub fn set_max_recursion_depth(&mut self, depth: usize) {
        self.max_recursion_depth = depth;
    }
97
    /// Set the maximum total iteration count for `.rept`/`.irp`/`.irpc`.
    ///
    /// The budget is shared by all loops within one `process()` call and is
    /// reset at the start of each call. Exceeding it yields
    /// `AsmError::ResourceLimitExceeded`.
    pub fn set_max_iterations(&mut self, count: usize) {
        self.max_iteration_count = count;
    }
102
103    /// Define a symbol for conditional assembly.
104    pub fn define_symbol(&mut self, name: &str, value: i128) {
105        self.symbols.insert(String::from(name), value);
106    }
107
108    /// Process source text, expanding all preprocessor directives.
109    ///
110    /// Returns the expanded source text ready for lexing/parsing.
111    /// When no preprocessor directives are present and no macros are defined,
112    /// returns a borrowed reference to the original source (zero allocation).
113    ///
114    /// # Errors
115    ///
116    /// Returns `AsmError` on malformed directives, recursion limit, or
117    /// iteration limit exceeded.
118    pub fn process<'a>(&mut self, source: &'a str) -> Result<Cow<'a, str>, AsmError> {
119        // Reset iteration count per process() call so long-lived assemblers
120        // don't accumulate towards the limit across multiple emit() calls.
121        self.iteration_count = 0;
122        if !self.needs_expansion(source) {
123            return Ok(Cow::Borrowed(source));
124        }
125        self.expand_text(source).map(Cow::Owned)
126    }
127
128    /// Check whether the source text requires preprocessing.
129    ///
130    /// Returns `false` when no macros are defined and the source contains
131    /// no preprocessor directives, meaning the text can be passed straight
132    /// to the lexer without any transformation.
133    fn needs_expansion(&self, source: &str) -> bool {
134        // If macros are defined, any line could be an invocation.
135        if !self.macros.is_empty() {
136            return true;
137        }
138        // If symbols are defined, we still don't need expansion unless
139        // the source actually references them via .ifdef/.ifndef.
140        // Scan for directive prefixes.  We look for lines whose first
141        // non-whitespace content matches a preprocessor directive.
142        // This is cheaper than the full expansion: no allocation, no
143        // string building — just a linear scan of the source bytes.
144        for line in source.lines() {
145            let trimmed = line.trim();
146            if trimmed.starts_with(".macro ")
147                || trimmed.starts_with(".macro\t")
148                || trimmed.starts_with(".rept ")
149                || trimmed.starts_with(".rept\t")
150                || trimmed.starts_with(".irp ")
151                || trimmed.starts_with(".irp\t")
152                || trimmed.starts_with(".irpc ")
153                || trimmed.starts_with(".irpc\t")
154                || trimmed.starts_with(".if ")
155                || trimmed.starts_with(".if\t")
156                || trimmed == ".if"
157                || trimmed.starts_with(".ifdef ")
158                || trimmed.starts_with(".ifdef\t")
159                || trimmed.starts_with(".ifndef ")
160                || trimmed.starts_with(".ifndef\t")
161                || trimmed == ".exitm"
162            {
163                return true;
164            }
165        }
166        false
167    }
168
169    /// Core expansion loop — processes one level of text.
170    fn expand_text(&mut self, source: &str) -> Result<String, AsmError> {
171        self.recursion_depth += 1;
172        if self.recursion_depth > self.max_recursion_depth {
173            self.recursion_depth -= 1;
174            return Err(AsmError::ResourceLimitExceeded {
175                resource: String::from("macro recursion depth"),
176                limit: self.max_recursion_depth,
177            });
178        }
179
180        let lines: Vec<&str> = source.lines().collect();
181        let mut output = String::new();
182        let mut i = 0;
183
184        let result = self.expand_text_inner(&lines, &mut output, &mut i);
185        self.recursion_depth -= 1;
186        result?;
187        Ok(output)
188    }
189
    /// Inner expansion logic (separated to ensure recursion_depth cleanup).
    ///
    /// Walks `lines` from index `*i`, appending expanded text to `output`.
    /// Each line is tested against the directive kinds in a fixed order:
    /// `.macro`, `.rept`, `.irp`, `.irpc`, conditionals, `.exitm`, then macro
    /// invocation. The first match consumes the line (or its whole block) and
    /// advances `*i`. Anything else is copied through with a trailing `\n`.
    ///
    /// # Errors
    ///
    /// Propagates errors from the block parsers and nested expansions, and
    /// returns `AsmError::Syntax` for `.exitm` at the outermost level.
    fn expand_text_inner(
        &mut self,
        lines: &[&str],
        output: &mut String,
        i: &mut usize,
    ) -> Result<(), AsmError> {
        while *i < lines.len() {
            let line = lines[*i];
            let trimmed = line.trim();

            // --- Macro definition ---
            // The whole `.macro` ... `.endm` block is recorded and emits no output.
            if trimmed.starts_with(".macro ") || trimmed.starts_with(".macro\t") {
                let (macro_def, end_idx) = self.parse_macro_def(lines, *i)?;
                let name = parse_macro_name(trimmed, *i)?;
                self.macros.insert(name, macro_def);
                *i = end_idx + 1;
                continue;
            }

            // --- .rept ---
            if trimmed.starts_with(".rept ") || trimmed.starts_with(".rept\t") {
                let (body, end_idx) = collect_block(lines, *i, ".rept", ".endr")?;
                let count = parse_rept_count(trimmed, *i)?;
                let expanded = self.expand_rept(count, &body)?;
                output.push_str(&expanded);
                *i = end_idx + 1;
                continue;
            }

            // --- .irp ---
            if trimmed.starts_with(".irp ") || trimmed.starts_with(".irp\t") {
                let (body, end_idx) = collect_block(lines, *i, ".irp", ".endr")?;
                let (sym, values) = parse_irp_args(trimmed, *i)?;
                let expanded = self.expand_irp(&sym, &values, &body)?;
                output.push_str(&expanded);
                *i = end_idx + 1;
                continue;
            }

            // --- .irpc ---
            if trimmed.starts_with(".irpc ") || trimmed.starts_with(".irpc\t") {
                let (body, end_idx) = collect_block(lines, *i, ".irpc", ".endr")?;
                let (sym, chars) = parse_irpc_args(trimmed, *i)?;
                let expanded = self.expand_irpc(&sym, &chars, &body)?;
                output.push_str(&expanded);
                *i = end_idx + 1;
                continue;
            }

            // --- Conditional assembly ---
            if trimmed.starts_with(".if ")
                || trimmed.starts_with(".if\t")
                || trimmed == ".if"
                || trimmed.starts_with(".ifdef ")
                || trimmed.starts_with(".ifdef\t")
                || trimmed.starts_with(".ifndef ")
                || trimmed.starts_with(".ifndef\t")
            {
                let (selected_body, end_idx) = self.process_conditional(lines, *i)?;
                // The selected branch is expanded as its own nested level.
                if !selected_body.is_empty() {
                    let expanded = self.expand_text(&selected_body)?;
                    output.push_str(&expanded);
                }
                *i = end_idx + 1;
                continue;
            }

            // --- .exitm (only meaningful inside macro expansion) ---
            if trimmed == ".exitm" {
                if self.recursion_depth <= 1 {
                    // At the outermost level `.exitm` has nothing to exit, so
                    // it is rejected with a syntax error.
                    return Err(AsmError::Syntax {
                        msg: String::from(".exitm outside of macro expansion"),
                        span: crate::error::Span::new((*i + 1) as u32, 1, 0, trimmed.len()),
                    });
                }
                // At any nested level, stop expanding this level only: the
                // remaining lines of the current text are dropped, while
                // enclosing levels continue normally.
                break;
            }

            // --- Macro invocation ---
            if let Some(expanded) = self.try_expand_macro(trimmed)? {
                // Recursively expand the result
                let re_expanded = self.expand_text(&expanded)?;
                output.push_str(&re_expanded);
                *i += 1;
                continue;
            }

            // --- .equ / .set / NAME = expr: track symbols for .ifdef ---
            // The defining line itself is still passed through below.
            if let Some((name, val)) = try_parse_symbol_def(trimmed) {
                self.symbols.insert(name, val);
            }

            // Ordinary line — pass through
            output.push_str(line);
            output.push('\n');
            *i += 1;
        }

        Ok(())
    }
293
294    /// Parse a `.macro name [params...]` ... `.endm` definition.
295    fn parse_macro_def(&self, lines: &[&str], start: usize) -> Result<(MacroDef, usize), AsmError> {
296        let header = lines[start].trim();
297        let params = parse_macro_params(header)?;
298
299        let mut body_lines = Vec::new();
300        let mut depth = 1usize;
301        let mut i = start + 1;
302
303        while i < lines.len() {
304            let trimmed = lines[i].trim();
305            if trimmed.starts_with(".macro ") || trimmed.starts_with(".macro\t") {
306                depth += 1;
307            } else if trimmed == ".endm" {
308                depth -= 1;
309                if depth == 0 {
310                    let body = body_lines.join("\n");
311                    return Ok((MacroDef { params, body }, i));
312                }
313            }
314            body_lines.push(lines[i]);
315            i += 1;
316        }
317
318        Err(AsmError::Syntax {
319            msg: String::from("unterminated .macro (missing .endm)"),
320            span: line_span(start),
321        })
322    }
323
324    /// Try to expand a line as a macro invocation. Returns `None` if no macro matches.
325    fn try_expand_macro(&mut self, line: &str) -> Result<Option<String>, AsmError> {
326        let trimmed = line.trim();
327        if trimmed.is_empty() || trimmed.starts_with('#') || trimmed.starts_with('.') {
328            return Ok(None);
329        }
330
331        // Extract first word as potential macro name
332        let first_word = trimmed.split_whitespace().next().unwrap_or("");
333
334        // Also check if it ends with ':' (label definition) — skip
335        if first_word.ends_with(':') {
336            // Could be `label: macro_name args` — check remainder
337            let rest = trimmed[first_word.len()..].trim();
338            if rest.is_empty() {
339                return Ok(None);
340            }
341            let macro_name = rest.split_whitespace().next().unwrap_or("");
342            if let Some(def) = self.macros.get(macro_name).cloned() {
343                let args_str = rest[macro_name.len()..].trim();
344                let args = parse_macro_args(args_str);
345                let expanded = self.substitute_macro(&def, &args);
346                // Preserve the label
347                return Ok(Some(format!("{}\n{}", first_word, expanded)));
348            }
349            return Ok(None);
350        }
351
352        if let Some(def) = self.macros.get(first_word).cloned() {
353            let args_str = trimmed[first_word.len()..].trim();
354            let args = parse_macro_args(args_str);
355            let expanded = self.substitute_macro(&def, &args);
356            return Ok(Some(expanded));
357        }
358
359        Ok(None)
360    }
361
    /// Substitute macro parameters and `\@` counter into body text.
    ///
    /// Walks the body once. At each `\` it first checks for `@`, then scans
    /// the parameter list linearly for a name that matches as a whole token.
    /// A `\` that matches neither is copied through unchanged.
    ///
    /// Arguments bind to parameters by position. A missing argument falls back
    /// to the parameter's default (or the empty string). A vararg parameter
    /// takes all remaining arguments joined with `", "`. Surplus arguments
    /// with no vararg parameter to absorb them are ignored.
    ///
    /// Each call consumes one value of the expansion counter, whether or not
    /// the body uses `\@`.
    fn substitute_macro(&mut self, def: &MacroDef, args: &[String]) -> String {
        let counter = self.expansion_counter;
        self.expansion_counter += 1;

        // Pre-compute replacement strings for each parameter
        let replacements: Vec<(&str, String)> = def
            .params
            .iter()
            .enumerate()
            .map(|(idx, param)| {
                let value = if param.is_vararg {
                    if idx < args.len() {
                        args[idx..].join(", ")
                    } else {
                        param.default.clone().unwrap_or_default()
                    }
                } else if idx < args.len() {
                    args[idx].clone()
                } else {
                    param.default.clone().unwrap_or_default()
                };
                (param.name.as_str(), value)
            })
            .collect();

        let body = &def.body;
        let mut result = String::with_capacity(body.len());
        let bytes = body.as_bytes();
        let len = bytes.len();
        let mut i = 0;

        while i < len {
            // A trailing lone `\` (no byte after it) falls to the copy branch.
            if bytes[i] == b'\\' && i + 1 < len {
                // Check for \@ (unique counter)
                if bytes[i + 1] == b'@' {
                    use core::fmt::Write;
                    let _ = write!(result, "{}", counter);
                    i += 2;
                    continue;
                }
                // Check for \param_name
                // `i + 1` is a char boundary because `\` is a single ASCII byte.
                let rest = &body[i + 1..];
                let mut matched = false;
                for &(name, ref value) in &replacements {
                    if rest.starts_with(name) {
                        // Ensure we match the full token — the char after the
                        // name must NOT be alphanumeric or '_' (otherwise
                        // \foo would partially match \foobar).
                        let end = name.len();
                        let boundary = end >= rest.len()
                            || !rest.as_bytes()[end].is_ascii_alphanumeric()
                                && rest.as_bytes()[end] != b'_';
                        if boundary {
                            result.push_str(value);
                            i += 1 + name.len();
                            matched = true;
                            break;
                        }
                    }
                }
                if !matched {
                    // Not a known reference: keep the backslash and rescan
                    // from the following byte.
                    result.push('\\');
                    i += 1;
                }
            } else {
                // Copy one character (handles multi-byte UTF-8)
                let ch = body[i..].chars().next().unwrap_or('\0');
                result.push(ch);
                i += ch.len_utf8();
            }
        }

        result
    }
441
442    /// Expand `.rept count` block.
443    fn expand_rept(&mut self, count: usize, body: &str) -> Result<String, AsmError> {
444        let mut raw = String::new();
445        for _ in 0..count {
446            self.iteration_count += 1;
447            if self.iteration_count > self.max_iteration_count {
448                return Err(AsmError::ResourceLimitExceeded {
449                    resource: String::from("preprocessor iterations"),
450                    limit: self.max_iteration_count,
451                });
452            }
453            raw.push_str(body);
454            raw.push('\n');
455        }
456        // Re-expand to handle nested .rept/.irp/.irpc/macros
457        self.expand_text(&raw)
458    }
459
460    /// Expand `.irp sym, val1, val2, ...` block.
461    fn expand_irp(&mut self, sym: &str, values: &[String], body: &str) -> Result<String, AsmError> {
462        let placeholder = format!("\\{}", sym);
463        let mut raw = String::new();
464        for val in values {
465            self.iteration_count += 1;
466            if self.iteration_count > self.max_iteration_count {
467                return Err(AsmError::ResourceLimitExceeded {
468                    resource: String::from("preprocessor iterations"),
469                    limit: self.max_iteration_count,
470                });
471            }
472            replace_single_param(&mut raw, body, &placeholder, val);
473            raw.push('\n');
474        }
475        self.expand_text(&raw)
476    }
477
478    /// Expand `.irpc sym, string` block.
479    fn expand_irpc(&mut self, sym: &str, chars: &str, body: &str) -> Result<String, AsmError> {
480        let placeholder = format!("\\{}", sym);
481        let mut raw = String::new();
482        let mut ch_buf = [0u8; 4];
483        for ch in chars.chars() {
484            self.iteration_count += 1;
485            if self.iteration_count > self.max_iteration_count {
486                return Err(AsmError::ResourceLimitExceeded {
487                    resource: String::from("preprocessor iterations"),
488                    limit: self.max_iteration_count,
489                });
490            }
491            let ch_str = ch.encode_utf8(&mut ch_buf);
492            replace_single_param(&mut raw, body, &placeholder, ch_str);
493            raw.push('\n');
494        }
495        self.expand_text(&raw)
496    }
497
    /// Process a conditional block (`.if`/`.ifdef`/`.ifndef`).
    /// Returns the selected body text and the line index of `.endif`.
    ///
    /// `lines[start]` must be the opening directive. The block is split into
    /// branches at the `.else`/`.elseif` lines of the outermost nesting level.
    /// The text of the first branch whose condition holds is returned
    /// unexpanded, or an empty string when no branch is selected. Nested
    /// conditionals are kept verbatim inside the branch they belong to.
    ///
    /// # Errors
    ///
    /// Returns `AsmError::Syntax` when the matching `.endif` is missing, and
    /// propagates errors from evaluating the opening condition.
    fn process_conditional(
        &self,
        lines: &[&str],
        start: usize,
    ) -> Result<(String, usize), AsmError> {
        let header = lines[start].trim();

        // Determine the initial condition result
        let condition = evaluate_condition(header, &self.symbols, start)?;

        // Completed branches so far, each with the condition it was opened under.
        let mut branches: Vec<(bool, Vec<&str>)> = Vec::new();
        let mut current_cond = condition;
        let mut current_body: Vec<&str> = Vec::new();
        let mut depth = 1usize;
        let mut i = start + 1;

        while i < lines.len() {
            let trimmed = lines[i].trim();

            // Nested conditional
            if trimmed.starts_with(".if ")
                || trimmed.starts_with(".if\t")
                || trimmed == ".if"
                || trimmed.starts_with(".ifdef ")
                || trimmed.starts_with(".ifdef\t")
                || trimmed.starts_with(".ifndef ")
                || trimmed.starts_with(".ifndef\t")
            {
                depth += 1;
                current_body.push(lines[i]);
                i += 1;
                continue;
            }

            if trimmed == ".endif" {
                depth -= 1;
                if depth == 0 {
                    branches.push((current_cond, current_body));
                    // Select first true branch
                    for (cond, body) in &branches {
                        if *cond {
                            return Ok((body.join("\n"), i));
                        }
                    }
                    return Ok((String::new(), i));
                }
                // `.endif` of a nested conditional: part of the current branch.
                current_body.push(lines[i]);
                i += 1;
                continue;
            }

            // Branch separators only count at the outermost level; inside a
            // nested conditional they fall through and are kept as body text.
            if depth == 1
                && (trimmed == ".else"
                    || trimmed.starts_with(".elseif ")
                    || trimmed.starts_with(".elseif\t"))
            {
                branches.push((current_cond, core::mem::take(&mut current_body)));
                if trimmed == ".else" {
                    // .else is true if no prior branch was taken
                    current_cond = !branches.iter().any(|(c, _)| *c);
                } else {
                    // .elseif expr
                    let expr_str = trimmed.strip_prefix(".elseif").unwrap().trim();
                    current_cond = if branches.iter().any(|(c, _)| *c) {
                        false // A prior branch was already taken
                    } else {
                        eval_simple_expr(expr_str, &self.symbols) != 0
                    };
                }
                i += 1;
                continue;
            }

            current_body.push(lines[i]);
            i += 1;
        }

        Err(AsmError::Syntax {
            msg: String::from("unterminated conditional (missing .endif)"),
            span: line_span(start),
        })
    }
582}
583
584impl Default for Preprocessor {
585    fn default() -> Self {
586        Self::new()
587    }
588}
589
590// --- Helper functions ---
591
592/// Parse the macro name from `.macro name ...`.
593fn parse_macro_name(header: &str, line: usize) -> Result<String, AsmError> {
594    let rest = header.strip_prefix(".macro").unwrap_or(header).trim_start();
595    let name = rest
596        .split(|c: char| c.is_whitespace() || c == ',')
597        .next()
598        .unwrap_or("");
599    if name.is_empty() {
600        return Err(AsmError::Syntax {
601            msg: String::from(".macro directive requires a name"),
602            span: Span::new((line + 1) as u32, 1, 0, header.len()),
603        });
604    }
605    Ok(String::from(name))
606}
607
608/// Parse macro parameters from `.macro name param1, param2=default, rest:vararg`.
609fn parse_macro_params(header: &str) -> Result<Vec<MacroParam>, AsmError> {
610    let rest = header.strip_prefix(".macro").unwrap_or(header).trim_start();
611
612    // Skip the macro name
613    let after_name = rest
614        .split_once(|c: char| c.is_whitespace() || c == ',')
615        .map(|(_, p)| p.trim_start_matches(',').trim())
616        .unwrap_or("");
617
618    if after_name.is_empty() {
619        return Ok(Vec::new());
620    }
621
622    let mut params = Vec::new();
623    for part in after_name.split(',') {
624        let part = part.trim();
625        if part.is_empty() {
626            continue;
627        }
628        if let Some((name, rest)) = part.split_once(':') {
629            let name = name.trim();
630            let rest = rest.trim();
631            if rest == "vararg" {
632                params.push(MacroParam {
633                    name: String::from(name),
634                    default: None,
635                    is_vararg: true,
636                });
637            } else {
638                params.push(MacroParam {
639                    name: String::from(part),
640                    default: None,
641                    is_vararg: false,
642                });
643            }
644        } else if let Some((name, default)) = part.split_once('=') {
645            params.push(MacroParam {
646                name: String::from(name.trim()),
647                default: Some(String::from(default.trim())),
648                is_vararg: false,
649            });
650        } else {
651            params.push(MacroParam {
652                name: String::from(part),
653                default: None,
654                is_vararg: false,
655            });
656        }
657    }
658    Ok(params)
659}
660
/// Parse arguments passed to a macro invocation.
///
/// Arguments are separated by commas and trimmed. Empty arguments are kept
/// (`a,,b` yields three), so positions still line up with parameters. An
/// empty `args_str` yields no arguments.
///
/// A comma does not split when it is:
///
/// * inside a double-quoted string (`\"` inside the string does not end it), or
/// * the character of a three-character literal such as `','`.
///
/// Quotes are preserved in the returned arguments. An unterminated string
/// extends to the end of the input.
fn parse_macro_args(args_str: &str) -> Vec<String> {
    if args_str.is_empty() {
        return Vec::new();
    }

    let mut args = Vec::new();
    let mut start = 0;
    let mut in_string = false;
    let mut chars = args_str.char_indices();

    while let Some((idx, ch)) = chars.next() {
        match ch {
            // Inside a string a backslash escapes the next character.
            '\\' if in_string => {
                chars.next();
            }
            '"' => in_string = !in_string,
            // `'X'` literal: step over the quoted character and closing quote
            // so neither `','` nor `'"'` disturbs the splitting state. A lone
            // `'` (e.g. `'a` without a closing quote) is left alone.
            '\'' if !in_string => {
                let mut ahead = chars.clone();
                if let (Some(_), Some((_, '\''))) = (ahead.next(), ahead.next()) {
                    chars = ahead;
                }
            }
            ',' if !in_string => {
                args.push(String::from(args_str[start..idx].trim()));
                start = idx + 1;
            }
            _ => {}
        }
    }
    args.push(String::from(args_str[start..].trim()));
    args
}
671
672/// Parse `.rept count` header.
673fn parse_rept_count(header: &str, line: usize) -> Result<usize, AsmError> {
674    let rest = header.strip_prefix(".rept").unwrap_or(header).trim();
675    rest.parse::<usize>().map_err(|_| AsmError::Syntax {
676        msg: format!("invalid .rept count: '{}'", rest),
677        span: Span::new((line + 1) as u32, 1, 0, header.len()),
678    })
679}
680
681/// Parse `.irp sym, val1, val2, ...` header.
682fn parse_irp_args(header: &str, line: usize) -> Result<(String, Vec<String>), AsmError> {
683    let rest = header.strip_prefix(".irp").unwrap_or(header).trim();
684    let (sym, vals_str) = rest.split_once(',').ok_or_else(|| AsmError::Syntax {
685        msg: String::from(".irp requires a symbol and a comma-separated value list"),
686        span: Span::new((line + 1) as u32, 1, 0, header.len()),
687    })?;
688    let sym = sym.trim();
689    let values: Vec<String> = vals_str
690        .split(',')
691        .map(|s| String::from(s.trim()))
692        .filter(|s| !s.is_empty())
693        .collect();
694    Ok((String::from(sym), values))
695}
696
697/// Parse `.irpc sym, chars` header.
698fn parse_irpc_args(header: &str, line: usize) -> Result<(String, String), AsmError> {
699    let rest = header.strip_prefix(".irpc").unwrap_or(header).trim();
700    let (sym, chars) = rest.split_once(',').ok_or_else(|| AsmError::Syntax {
701        msg: String::from(".irpc requires a symbol and a string"),
702        span: Span::new((line + 1) as u32, 1, 0, header.len()),
703    })?;
704    Ok((String::from(sym.trim()), String::from(chars.trim())))
705}
706
707/// Collect lines of a block between `open_directive` and `close_directive`,
708/// handling nesting.
709fn collect_block(
710    lines: &[&str],
711    start: usize,
712    open_kw: &str,
713    close_kw: &str,
714) -> Result<(String, usize), AsmError> {
715    let mut depth = 1usize;
716    let mut body_lines = Vec::new();
717    let mut i = start + 1;
718
719    // All directives that share `.endr` as their terminator.
720    let endr_openers: &[&str] = &[".rept", ".irp", ".irpc"];
721
722    while i < lines.len() {
723        let trimmed = lines[i].trim();
724
725        // Check for nested open — if `.endr` is the terminator we must
726        // count *any* `.rept`/`.irp`/`.irpc` as nesting, not just the
727        // exact `open_kw`.
728        if close_kw == ".endr" {
729            for &opener in endr_openers {
730                if trimmed.starts_with(opener)
731                    && (trimmed.len() == opener.len()
732                        || trimmed.as_bytes().get(opener.len()) == Some(&b' ')
733                        || trimmed.as_bytes().get(opener.len()) == Some(&b'\t'))
734                {
735                    depth += 1;
736                    break;
737                }
738            }
739        } else if trimmed.starts_with(open_kw)
740            && (trimmed.len() == open_kw.len()
741                || trimmed.as_bytes().get(open_kw.len()) == Some(&b' ')
742                || trimmed.as_bytes().get(open_kw.len()) == Some(&b'\t'))
743        {
744            depth += 1;
745        }
746
747        if trimmed == close_kw {
748            depth -= 1;
749            if depth == 0 {
750                return Ok((body_lines.join("\n"), i));
751            }
752        }
753
754        body_lines.push(lines[i]);
755        i += 1;
756    }
757
758    Err(AsmError::Syntax {
759        msg: format!("unterminated {} (missing {})", open_kw, close_kw),
760        span: line_span(start),
761    })
762}
763
764/// Evaluate a conditional directive header.
765fn evaluate_condition(
766    header: &str,
767    symbols: &BTreeMap<String, i128>,
768    line: usize,
769) -> Result<bool, AsmError> {
770    let trimmed = header.trim();
771
772    if let Some(rest) = trimmed.strip_prefix(".ifdef") {
773        let name = rest.trim();
774        return Ok(symbols.contains_key(name));
775    }
776
777    if let Some(rest) = trimmed.strip_prefix(".ifndef") {
778        let name = rest.trim();
779        return Ok(!symbols.contains_key(name));
780    }
781
782    if let Some(rest) = trimmed.strip_prefix(".if") {
783        let expr = rest.trim();
784        return Ok(eval_simple_expr(expr, symbols) != 0);
785    }
786
787    Err(AsmError::Syntax {
788        msg: format!("unrecognized conditional directive: {}", trimmed),
789        span: Span::new((line + 1) as u32, 1, 0, header.len()),
790    })
791}
792
/// Recursive-descent expression evaluator with proper C-like operator precedence.
///
/// Precedence (lowest → highest):
///  1. `||`  logical OR
///  2. `&&`  logical AND
///  3. `|`   bitwise OR
///  4. `^`   bitwise XOR
///  5. `&`   bitwise AND
///  6. `==` `!=`  equality
///  7. `<` `>` `<=` `>=`  relational
///  8. `<<` `>>`  shift
///  9. `+` `-`  additive
/// 10. `*` `/` `%`  multiplicative
/// 11. `!` `-` `~`  unary prefix
/// 12. literals, symbols, `defined()`, `(expr)`
struct ExprEval<'a> {
    /// The expression text as raw bytes; scanning is byte-oriented.
    src: &'a [u8],
    /// Index into `src` of the next unread byte.
    pos: usize,
    /// Symbol table (name → value) available to the evaluator.
    symbols: &'a BTreeMap<String, i128>,
}
813
814impl<'a> ExprEval<'a> {
815    fn new(expr: &'a str, symbols: &'a BTreeMap<String, i128>) -> Self {
816        Self {
817            src: expr.as_bytes(),
818            pos: 0,
819            symbols,
820        }
821    }
822
823    fn eval(mut self) -> i128 {
824        self.skip_ws();
825        if self.pos >= self.src.len() {
826            return 0;
827        }
828        self.parse_logical_or()
829    }
830
831    fn skip_ws(&mut self) {
832        while self.pos < self.src.len() && self.src[self.pos].is_ascii_whitespace() {
833            self.pos += 1;
834        }
835    }
836
837    /// Try to consume a two-byte operator token. Returns `true` on match.
838    fn eat2(&mut self, c1: u8, c2: u8) -> bool {
839        self.skip_ws();
840        if self.pos + 1 < self.src.len() && self.src[self.pos] == c1 && self.src[self.pos + 1] == c2
841        {
842            self.pos += 2;
843            true
844        } else {
845            false
846        }
847    }
848
849    // ── precedence 1: || ──────────────────────────────────────────────
850    fn parse_logical_or(&mut self) -> i128 {
851        let mut v = self.parse_logical_and();
852        while self.eat2(b'|', b'|') {
853            let r = self.parse_logical_and();
854            v = if v != 0 || r != 0 { 1 } else { 0 };
855        }
856        v
857    }
858
859    // ── precedence 2: && ──────────────────────────────────────────────
860    fn parse_logical_and(&mut self) -> i128 {
861        let mut v = self.parse_bitwise_or();
862        while self.eat2(b'&', b'&') {
863            let r = self.parse_bitwise_or();
864            v = if v != 0 && r != 0 { 1 } else { 0 };
865        }
866        v
867    }
868
869    // ── precedence 3: | (but not ||) ─────────────────────────────────
870    fn parse_bitwise_or(&mut self) -> i128 {
871        let mut v = self.parse_bitwise_xor();
872        loop {
873            self.skip_ws();
874            if self.pos < self.src.len() && self.src[self.pos] == b'|' {
875                // Distinguish | from ||
876                if self.pos + 1 < self.src.len() && self.src[self.pos + 1] == b'|' {
877                    break;
878                }
879                self.pos += 1;
880                v |= self.parse_bitwise_xor();
881            } else {
882                break;
883            }
884        }
885        v
886    }
887
888    // ── precedence 4: ^ ──────────────────────────────────────────────
889    fn parse_bitwise_xor(&mut self) -> i128 {
890        let mut v = self.parse_bitwise_and();
891        loop {
892            self.skip_ws();
893            if self.pos < self.src.len() && self.src[self.pos] == b'^' {
894                self.pos += 1;
895                v ^= self.parse_bitwise_and();
896            } else {
897                break;
898            }
899        }
900        v
901    }
902
903    // ── precedence 5: & (but not &&) ─────────────────────────────────
904    fn parse_bitwise_and(&mut self) -> i128 {
905        let mut v = self.parse_equality();
906        loop {
907            self.skip_ws();
908            if self.pos < self.src.len() && self.src[self.pos] == b'&' {
909                if self.pos + 1 < self.src.len() && self.src[self.pos + 1] == b'&' {
910                    break;
911                }
912                self.pos += 1;
913                v &= self.parse_equality();
914            } else {
915                break;
916            }
917        }
918        v
919    }
920
921    // ── precedence 6: == != ──────────────────────────────────────────
922    fn parse_equality(&mut self) -> i128 {
923        let mut v = self.parse_relational();
924        loop {
925            if self.eat2(b'=', b'=') {
926                let r = self.parse_relational();
927                v = if v == r { 1 } else { 0 };
928            } else if self.eat2(b'!', b'=') {
929                let r = self.parse_relational();
930                v = if v == r { 0 } else { 1 };
931            } else {
932                break;
933            }
934        }
935        v
936    }
937
938    // ── precedence 7: < > <= >= ──────────────────────────────────────
939    fn parse_relational(&mut self) -> i128 {
940        let mut v = self.parse_shift();
941        loop {
942            if self.eat2(b'<', b'=') {
943                v = if v <= self.parse_shift() { 1 } else { 0 };
944            } else if self.eat2(b'>', b'=') {
945                v = if v >= self.parse_shift() { 1 } else { 0 };
946            } else {
947                self.skip_ws();
948                if self.pos < self.src.len() && self.src[self.pos] == b'<' {
949                    // Not << or <=
950                    if self.pos + 1 < self.src.len()
951                        && (self.src[self.pos + 1] == b'<' || self.src[self.pos + 1] == b'=')
952                    {
953                        break;
954                    }
955                    self.pos += 1;
956                    v = if v < self.parse_shift() { 1 } else { 0 };
957                } else if self.pos < self.src.len() && self.src[self.pos] == b'>' {
958                    if self.pos + 1 < self.src.len()
959                        && (self.src[self.pos + 1] == b'>' || self.src[self.pos + 1] == b'=')
960                    {
961                        break;
962                    }
963                    self.pos += 1;
964                    v = if v > self.parse_shift() { 1 } else { 0 };
965                } else {
966                    break;
967                }
968            }
969        }
970        v
971    }
972
973    // ── precedence 8: << >> ──────────────────────────────────────────
974    fn parse_shift(&mut self) -> i128 {
975        let mut v = self.parse_additive();
976        loop {
977            if self.eat2(b'<', b'<') {
978                let r = self.parse_additive();
979                v = if (0..128).contains(&r) {
980                    v.wrapping_shl(r as u32)
981                } else {
982                    0
983                };
984            } else if self.eat2(b'>', b'>') {
985                let r = self.parse_additive();
986                v = if (0..128).contains(&r) {
987                    v.wrapping_shr(r as u32)
988                } else {
989                    0
990                };
991            } else {
992                break;
993            }
994        }
995        v
996    }
997
998    // ── precedence 9: + - ────────────────────────────────────────────
999    fn parse_additive(&mut self) -> i128 {
1000        let mut v = self.parse_multiplicative();
1001        loop {
1002            self.skip_ws();
1003            if self.pos < self.src.len() && self.src[self.pos] == b'+' {
1004                self.pos += 1;
1005                v = v.wrapping_add(self.parse_multiplicative());
1006            } else if self.pos < self.src.len() && self.src[self.pos] == b'-' {
1007                self.pos += 1;
1008                v = v.wrapping_sub(self.parse_multiplicative());
1009            } else {
1010                break;
1011            }
1012        }
1013        v
1014    }
1015
1016    // ── precedence 10: * / % ─────────────────────────────────────────
1017    fn parse_multiplicative(&mut self) -> i128 {
1018        let mut v = self.parse_unary();
1019        loop {
1020            self.skip_ws();
1021            if self.pos < self.src.len() && self.src[self.pos] == b'*' {
1022                self.pos += 1;
1023                v = v.wrapping_mul(self.parse_unary());
1024            } else if self.pos < self.src.len() && self.src[self.pos] == b'/' {
1025                self.pos += 1;
1026                let r = self.parse_unary();
1027                v = if r != 0 { v / r } else { 0 };
1028            } else if self.pos < self.src.len() && self.src[self.pos] == b'%' {
1029                self.pos += 1;
1030                let r = self.parse_unary();
1031                v = if r != 0 { v % r } else { 0 };
1032            } else {
1033                break;
1034            }
1035        }
1036        v
1037    }
1038
1039    // ── precedence 11: unary ! - ~ ───────────────────────────────────
1040    fn parse_unary(&mut self) -> i128 {
1041        self.skip_ws();
1042        if self.pos < self.src.len() {
1043            match self.src[self.pos] {
1044                // Logical NOT (but not !=)
1045                b'!' if self.pos + 1 >= self.src.len() || self.src[self.pos + 1] != b'=' => {
1046                    self.pos += 1;
1047                    let v = self.parse_unary();
1048                    return if v == 0 { 1 } else { 0 };
1049                }
1050                b'-' => {
1051                    self.pos += 1;
1052                    return self.parse_unary().wrapping_neg();
1053                }
1054                b'~' => {
1055                    self.pos += 1;
1056                    return !self.parse_unary();
1057                }
1058                _ => {}
1059            }
1060        }
1061        self.parse_primary()
1062    }
1063
1064    // ── precedence 12: atoms ─────────────────────────────────────────
1065    fn parse_primary(&mut self) -> i128 {
1066        self.skip_ws();
1067        if self.pos >= self.src.len() {
1068            return 0;
1069        }
1070        let ch = self.src[self.pos];
1071
1072        // Parenthesised sub-expression
1073        if ch == b'(' {
1074            self.pos += 1;
1075            let v = self.parse_logical_or();
1076            self.skip_ws();
1077            if self.pos < self.src.len() && self.src[self.pos] == b')' {
1078                self.pos += 1;
1079            }
1080            return v;
1081        }
1082
1083        // Numeric literal (decimal, 0x, 0b, 0o)
1084        if ch.is_ascii_digit() {
1085            return self.parse_number();
1086        }
1087
1088        // Character literal 'c'
1089        if ch == b'\'' && self.pos + 2 < self.src.len() && self.src[self.pos + 2] == b'\'' {
1090            let c = self.src[self.pos + 1];
1091            self.pos += 3;
1092            return c as i128;
1093        }
1094
1095        // Identifier: symbol name or `defined()`
1096        if ch.is_ascii_alphabetic() || ch == b'_' || ch == b'.' {
1097            let start = self.pos;
1098            while self.pos < self.src.len() {
1099                let c = self.src[self.pos];
1100                if c.is_ascii_alphanumeric() || c == b'_' || c == b'.' {
1101                    self.pos += 1;
1102                } else {
1103                    break;
1104                }
1105            }
1106            let name = core::str::from_utf8(&self.src[start..self.pos]).unwrap_or("");
1107
1108            // `defined(sym)` pseudo-function
1109            if name == "defined" {
1110                self.skip_ws();
1111                if self.pos < self.src.len() && self.src[self.pos] == b'(' {
1112                    self.pos += 1;
1113                    self.skip_ws();
1114                    let s = self.pos;
1115                    while self.pos < self.src.len() {
1116                        let c = self.src[self.pos];
1117                        if c.is_ascii_alphanumeric() || c == b'_' || c == b'.' {
1118                            self.pos += 1;
1119                        } else {
1120                            break;
1121                        }
1122                    }
1123                    let sym = core::str::from_utf8(&self.src[s..self.pos]).unwrap_or("");
1124                    self.skip_ws();
1125                    if self.pos < self.src.len() && self.src[self.pos] == b')' {
1126                        self.pos += 1;
1127                    }
1128                    return if self.symbols.contains_key(sym) { 1 } else { 0 };
1129                }
1130            }
1131
1132            if let Some(&val) = self.symbols.get(name) {
1133                return val;
1134            }
1135            return 0; // unknown symbol → 0
1136        }
1137
1138        0
1139    }
1140
1141    /// Parse a numeric literal at the current position.
1142    fn parse_number(&mut self) -> i128 {
1143        if self.src[self.pos] == b'0' && self.pos + 1 < self.src.len() {
1144            match self.src[self.pos + 1] {
1145                b'x' | b'X' => {
1146                    self.pos += 2;
1147                    let start = self.pos;
1148                    while self.pos < self.src.len() && self.src[self.pos].is_ascii_hexdigit() {
1149                        self.pos += 1;
1150                    }
1151                    let s = core::str::from_utf8(&self.src[start..self.pos]).unwrap_or("0");
1152                    return i128::from_str_radix(s, 16).unwrap_or(0);
1153                }
1154                b'b' | b'B' => {
1155                    self.pos += 2;
1156                    let start = self.pos;
1157                    while self.pos < self.src.len() && matches!(self.src[self.pos], b'0' | b'1') {
1158                        self.pos += 1;
1159                    }
1160                    let s = core::str::from_utf8(&self.src[start..self.pos]).unwrap_or("0");
1161                    return i128::from_str_radix(s, 2).unwrap_or(0);
1162                }
1163                b'o' | b'O' => {
1164                    self.pos += 2;
1165                    let start = self.pos;
1166                    while self.pos < self.src.len() && matches!(self.src[self.pos], b'0'..=b'7') {
1167                        self.pos += 1;
1168                    }
1169                    let s = core::str::from_utf8(&self.src[start..self.pos]).unwrap_or("0");
1170                    return i128::from_str_radix(s, 8).unwrap_or(0);
1171                }
1172                _ => {}
1173            }
1174        }
1175        // Plain decimal
1176        let start = self.pos;
1177        while self.pos < self.src.len() && self.src[self.pos].is_ascii_digit() {
1178            self.pos += 1;
1179        }
1180        let s = core::str::from_utf8(&self.src[start..self.pos]).unwrap_or("0");
1181        s.parse::<i128>().unwrap_or(0)
1182    }
1183}
1184
1185/// Evaluate a simple integer expression (with symbol lookup).
1186///
1187/// Uses a recursive-descent parser with full C-like operator precedence.
1188/// Supports all arithmetic, bitwise, shift, logical, and comparison operators,
1189/// parenthesised sub-expressions, `defined()`, and numeric/symbol atoms.
1190fn eval_simple_expr(expr: &str, symbols: &BTreeMap<String, i128>) -> i128 {
1191    ExprEval::new(expr.trim(), symbols).eval()
1192}
1193
1194/// Try to parse a symbol definition from `.equ name, value` or `name = value`.
1195fn try_parse_symbol_def(line: &str) -> Option<(String, i128)> {
1196    let trimmed = line.trim();
1197
1198    // `.equ name, value` or `.set name, value`
1199    for prefix in &[".equ ", ".set "] {
1200        if let Some(rest) = trimmed.strip_prefix(prefix) {
1201            let rest = rest.trim();
1202            if let Some((name, val_str)) = rest.split_once(',') {
1203                if let Ok(val) = parse_int_literal(val_str.trim()) {
1204                    return Some((String::from(name.trim()), val));
1205                }
1206            }
1207        }
1208    }
1209
1210    // `name = value`
1211    if let Some((name, val_str)) = trimmed.split_once('=') {
1212        let name = name.trim();
1213        let val_str = val_str.trim();
1214        // Must not start with '=' (that would be '==')
1215        if !val_str.is_empty()
1216            && !val_str.starts_with('=')
1217            && name.chars().all(|c| c.is_alphanumeric() || c == '_')
1218        {
1219            if let Ok(val) = parse_int_literal(val_str) {
1220                return Some((String::from(name), val));
1221            }
1222        }
1223    }
1224
1225    None
1226}
1227
/// Parse an integer literal (decimal, hex, octal, binary).
///
/// Accepted form, after trimming surrounding whitespace: an optional `+`/`-`
/// sign, an optional radix prefix (`0x`, `0b`, `0o`, either case), then at
/// least one digit of that radix. The sign may only appear *before* the
/// prefix: `-0x10` is -16, while `0x-5` is rejected.
///
/// # Errors
///
/// Returns `Err(())` when the text is empty, contains anything other than
/// the form above, or the value does not fit in `i128`.
fn parse_int_literal(s: &str) -> Result<i128, ()> {
    let s = s.trim();

    let (negative, unsigned) = match s.as_bytes().first() {
        Some(b'-') => (true, &s[1..]),
        Some(b'+') => (false, &s[1..]),
        _ => (false, s),
    };

    let (radix, digits) = if let Some(hex) = unsigned
        .strip_prefix("0x")
        .or_else(|| unsigned.strip_prefix("0X"))
    {
        (16, hex)
    } else if let Some(bin) = unsigned
        .strip_prefix("0b")
        .or_else(|| unsigned.strip_prefix("0B"))
    {
        (2, bin)
    } else if let Some(oct) = unsigned
        .strip_prefix("0o")
        .or_else(|| unsigned.strip_prefix("0O"))
    {
        (8, oct)
    } else {
        (10, unsigned)
    };

    // `from_str_radix` would itself accept a leading `+`; a second sign (or a
    // sign after the radix prefix) is not part of a literal. A `-` here is
    // already rejected by the unsigned parse below.
    if digits.starts_with('+') {
        return Err(());
    }

    // Parse the magnitude unsigned so that `i128::MIN` is reachable.
    let magnitude = u128::from_str_radix(digits, radix).map_err(|_| ())?;
    let limit = if negative {
        1u128 << 127
    } else {
        i128::MAX as u128
    };
    if magnitude > limit {
        return Err(());
    }

    // 2^127 casts to `i128::MIN`, which is its own wrapping negation.
    let value = magnitude as i128;
    Ok(if negative { value.wrapping_neg() } else { value })
}
1241
1242/// Create a dummy span for a given line index.
1243fn line_span(line: usize) -> Span {
1244    Span::new((line + 1) as u32, 1, 0, 0)
1245}
1246
1247#[cfg(test)]
1248mod tests {
1249    use super::*;
1250
1251    // === Macro definition and expansion ===
1252
/// Both positional arguments are substituted into the macro body.
#[test]
fn macro_simple_expansion() {
    let input = concat!(
        ".macro push_pair r1, r2\n",
        "    push \\r1\n",
        "    push \\r2\n",
        ".endm\n",
        "push_pair rax, rbx\n",
    );
    let mut preprocessor = Preprocessor::new();
    let expanded = preprocessor.process(input).unwrap();
    for needle in &["push rax", "push rbx"] {
        assert!(expanded.contains(*needle));
    }
}
1267
/// Omitted arguments fall back to the defaults declared in the macro header.
#[test]
fn macro_with_defaults() {
    let input = concat!(
        ".macro load_imm reg=rax, val=0\n",
        "    mov \\reg, \\val\n",
        ".endm\n",
        "load_imm\n",
        "load_imm rcx, 42\n",
    );
    let mut preprocessor = Preprocessor::new();
    let expanded = preprocessor.process(input).unwrap();
    // First invocation uses the defaults, second one overrides both.
    for needle in &["mov rax, 0", "mov rcx, 42"] {
        assert!(expanded.contains(*needle));
    }
}
1282
/// Two invocations of a macro using `\@` yield `label_0` and `label_1`.
#[test]
fn macro_unique_labels() {
    let input = concat!(
        ".macro my_loop\n",
        "    jmp label_\\@\n",
        "label_\\@:\n",
        ".endm\n",
        "my_loop\n",
        "my_loop\n",
    );
    let mut preprocessor = Preprocessor::new();
    let expanded = preprocessor.process(input).unwrap();
    for needle in &["label_0", "label_1"] {
        assert!(expanded.contains(*needle));
    }
}
1298
/// A macro that invokes itself must be stopped by the recursion limit.
#[test]
fn macro_recursion_limit() {
    let input = concat!(
        ".macro recurse\n",
        "    nop\n",
        "    recurse\n",
        ".endm\n",
        "recurse\n",
    );
    let mut preprocessor = Preprocessor::new();
    let failure = preprocessor.process(input).unwrap_err();
    if let AsmError::ResourceLimitExceeded { resource, .. } = &failure {
        assert!(resource.contains("recursion"));
    } else {
        panic!("expected ResourceLimitExceeded, got {:?}", failure);
    }
}
1317
/// A `:vararg` parameter receives the whole remaining argument list.
#[test]
fn macro_vararg() {
    let input = concat!(
        ".macro pushall regs:vararg\n",
        "    # push \\regs\n",
        ".endm\n",
        "pushall rax, rbx, rcx\n",
    );
    let mut preprocessor = Preprocessor::new();
    let expanded = preprocessor.process(input).unwrap();
    assert!(expanded.contains("rax, rbx, rcx"));
}
1330
/// A macro body is terminated by its `.endm` and expands on invocation.
#[test]
fn macro_nested_endm() {
    // NOTE(review): despite the test name, this input contains a single,
    // non-nested macro definition; nested `.macro`/`.endm` pairs are not
    // exercised here.
    let input = concat!(
        ".macro outer\n",
        "    nop\n",
        ".endm\n",
        "outer\n",
    );
    let mut preprocessor = Preprocessor::new();
    let expanded = preprocessor.process(input).unwrap();
    assert!(expanded.contains("nop"));
}
1344
1345    // === .rept ===
1346
/// `.rept 3` emits its body three times.
#[test]
fn rept_basic() {
    let input = concat!(".rept 3\n", "    nop\n", ".endr\n");
    let mut preprocessor = Preprocessor::new();
    let expanded = preprocessor.process(input).unwrap();
    assert_eq!(expanded.matches("nop").count(), 3);
}
1359
/// `.rept 0` emits nothing from its body.
#[test]
fn rept_zero() {
    let input = concat!(".rept 0\n", "    nop\n", ".endr\n");
    let mut preprocessor = Preprocessor::new();
    let expanded = preprocessor.process(input).unwrap();
    assert!(!expanded.contains("nop"));
}
1371
/// Nested `.rept` blocks multiply: 2 outer × 3 inner = 6 copies.
#[test]
fn rept_nested() {
    let input = concat!(
        ".rept 2\n",
        ".rept 3\n",
        "    nop\n",
        ".endr\n",
        ".endr\n",
    );
    let mut preprocessor = Preprocessor::new();
    let expanded = preprocessor.process(input).unwrap();
    assert_eq!(expanded.matches("nop").count(), 6);
}
1386
1387    // === .irp ===
1388
/// `.irp` expands its body once per listed value.
#[test]
fn irp_basic() {
    let input = concat!(
        ".irp reg, rax, rbx, rcx\n",
        "    push \\reg\n",
        ".endr\n",
    );
    let mut preprocessor = Preprocessor::new();
    let expanded = preprocessor.process(input).unwrap();
    for needle in &["push rax", "push rbx", "push rcx"] {
        assert!(expanded.contains(*needle));
    }
}
1402
1403    // === .irpc ===
1404
/// `.irpc` expands its body once per character of the argument.
#[test]
fn irpc_basic() {
    let input = concat!(
        ".irpc c, abc\n",
        "    .byte '\\c'\n",
        ".endr\n",
    );
    let mut preprocessor = Preprocessor::new();
    let expanded = preprocessor.process(input).unwrap();
    for needle in &["'a'", "'b'", "'c'"] {
        assert!(expanded.contains(*needle));
    }
}
1418
1419    // === Conditional assembly ===
1420
/// The body of `.if 1` is kept.
#[test]
fn if_true() {
    let input = concat!(".if 1\n", "    nop\n", ".endif\n");
    let mut preprocessor = Preprocessor::new();
    let expanded = preprocessor.process(input).unwrap();
    assert!(expanded.contains("nop"));
}
1432
/// The body of `.if 0` is dropped.
#[test]
fn if_false() {
    let input = concat!(".if 0\n", "    nop\n", ".endif\n");
    let mut preprocessor = Preprocessor::new();
    let expanded = preprocessor.process(input).unwrap();
    assert!(!expanded.contains("nop"));
}
1444
/// With a false condition only the `.else` branch survives.
#[test]
fn if_else() {
    let input = concat!(
        ".if 0\n",
        "    mov rax, 1\n",
        ".else\n",
        "    mov rax, 2\n",
        ".endif\n",
    );
    let mut preprocessor = Preprocessor::new();
    let expanded = preprocessor.process(input).unwrap();
    assert!(!expanded.contains("mov rax, 1"));
    assert!(expanded.contains("mov rax, 2"));
}
1459
/// Only the first true branch of an `.if`/`.elseif`/`.else` chain is emitted.
#[test]
fn if_elseif() {
    let input = concat!(
        ".if 0\n",
        "    mov rax, 1\n",
        ".elseif 1\n",
        "    mov rax, 2\n",
        ".else\n",
        "    mov rax, 3\n",
        ".endif\n",
    );
    let mut preprocessor = Preprocessor::new();
    let expanded = preprocessor.process(input).unwrap();
    assert!(!expanded.contains("mov rax, 1"));
    assert!(expanded.contains("mov rax, 2"));
    assert!(!expanded.contains("mov rax, 3"));
}
1477
/// `.ifdef` keeps its body when the symbol was defined through the API.
#[test]
fn ifdef_defined() {
    let input = concat!(".ifdef MY_FLAG\n", "    nop\n", ".endif\n");
    let mut preprocessor = Preprocessor::new();
    preprocessor.define_symbol("MY_FLAG", 1);
    let expanded = preprocessor.process(input).unwrap();
    assert!(expanded.contains("nop"));
}
1490
/// `.ifdef` drops its body when the symbol is unknown.
#[test]
fn ifdef_undefined() {
    let input = concat!(".ifdef UNDEFINED_FLAG\n", "    nop\n", ".endif\n");
    let mut preprocessor = Preprocessor::new();
    let expanded = preprocessor.process(input).unwrap();
    assert!(!expanded.contains("nop"));
}
1502
/// `.ifndef` keeps its body when the symbol is unknown.
#[test]
fn ifndef_undefined() {
    let input = concat!(".ifndef MY_FLAG\n", "    nop\n", ".endif\n");
    let mut preprocessor = Preprocessor::new();
    let expanded = preprocessor.process(input).unwrap();
    assert!(expanded.contains("nop"));
}
1514
/// An inner `.ifdef` inside a true outer `.ifdef` is evaluated as well.
#[test]
fn nested_conditionals() {
    let input = concat!(
        ".ifdef OUTER\n",
        "    .ifdef INNER\n",
        "        nop\n",
        "    .endif\n",
        ".endif\n",
    );
    let mut preprocessor = Preprocessor::new();
    for flag in &["OUTER", "INNER"] {
        preprocessor.define_symbol(flag, 1);
    }
    let expanded = preprocessor.process(input).unwrap();
    assert!(expanded.contains("nop"));
}
1530
/// `.if` expressions may reference predefined symbols (`X = 5`, so `X > 3`).
#[test]
fn if_expression_with_symbols() {
    let input = concat!(".if X > 3\n", "    nop\n", ".endif\n");
    let mut preprocessor = Preprocessor::new();
    preprocessor.define_symbol("X", 5);
    let expanded = preprocessor.process(input).unwrap();
    assert!(expanded.contains("nop"));
}
1543
/// A symbol introduced by `.equ` is visible to a later `.ifdef`.
#[test]
fn equ_tracks_symbols() {
    let input = concat!(
        ".equ MY_CONST, 42\n",
        ".ifdef MY_CONST\n",
        "    nop\n",
        ".endif\n",
    );
    let mut preprocessor = Preprocessor::new();
    let expanded = preprocessor.process(input).unwrap();
    assert!(expanded.contains("nop"));
    // The `.equ` line itself is also forwarded so the parser still sees it.
    assert!(expanded.contains(".equ MY_CONST, 42"));
}
1558
/// `defined(X)` inside `.if` is true for a predefined symbol.
#[test]
fn if_defined_function() {
    let input = concat!(".if defined(X)\n", "    nop\n", ".endif\n");
    let mut preprocessor = Preprocessor::new();
    preprocessor.define_symbol("X", 1);
    let expanded = preprocessor.process(input).unwrap();
    assert!(expanded.contains("nop"));
}
1571
1572    // === Error cases ===
1573
/// A `.macro` without `.endm` is reported as a syntax error.
#[test]
fn unterminated_macro() {
    let mut preprocessor = Preprocessor::new();
    let failure = preprocessor.process(".macro foo\n    nop\n").unwrap_err();
    let msg = match failure {
        AsmError::Syntax { msg, .. } => msg,
        _ => panic!("expected Syntax error"),
    };
    assert!(msg.contains("unterminated .macro"));
}
1586
/// A `.rept` without `.endr` is reported as a syntax error.
#[test]
fn unterminated_rept() {
    let mut preprocessor = Preprocessor::new();
    let failure = preprocessor.process(".rept 3\n    nop\n").unwrap_err();
    let msg = match failure {
        AsmError::Syntax { msg, .. } => msg,
        _ => panic!("expected Syntax error"),
    };
    assert!(msg.contains("unterminated"));
}
1599
/// An `.if` without `.endif` is reported as a syntax error.
#[test]
fn unterminated_conditional() {
    let mut preprocessor = Preprocessor::new();
    let failure = preprocessor.process(".if 1\n    nop\n").unwrap_err();
    let msg = match failure {
        AsmError::Syntax { msg, .. } => msg,
        _ => panic!("expected Syntax error"),
    };
    assert!(msg.contains("unterminated"));
}
1612
/// A `.rept` count of 200000 trips the iteration limit.
#[test]
fn iteration_limit() {
    let mut preprocessor = Preprocessor::new();
    let failure = preprocessor
        .process(".rept 200000\n    nop\n.endr\n")
        .unwrap_err();
    let resource = match failure {
        AsmError::ResourceLimitExceeded { resource, .. } => resource,
        _ => panic!("expected ResourceLimitExceeded"),
    };
    assert!(resource.contains("iteration"));
}
1625
1626    // === Expression evaluator tests ===
1627
1628    /// Helper: evaluate expression with given symbols.
1629    fn eval(expr: &str) -> i128 {
1630        let syms = BTreeMap::new();
1631        super::eval_simple_expr(expr, &syms)
1632    }
1633
1634    fn eval_with(expr: &str, syms: &BTreeMap<String, i128>) -> i128 {
1635        super::eval_simple_expr(expr, syms)
1636    }
1637
/// Plain decimal literals evaluate to themselves.
#[test]
fn expr_decimal_literals() {
    let cases: &[(&str, i128)] = &[("0", 0), ("42", 42), ("123456789", 123_456_789)];
    for &(expr, expected) in cases {
        assert_eq!(eval(expr), expected);
    }
}
1644
/// `0x` / `0X` prefixed literals are read as hexadecimal.
#[test]
fn expr_hex_literals() {
    let cases: &[(&str, i128)] = &[("0xFF", 255), ("0x10", 16), ("0XAB", 0xAB)];
    for &(expr, expected) in cases {
        assert_eq!(eval(expr), expected);
    }
}
1651
/// `0b` / `0B` prefixed literals are read as binary.
#[test]
fn expr_binary_literals() {
    let cases: &[(&str, i128)] = &[("0b1010", 10), ("0B11111111", 255)];
    for &(expr, expected) in cases {
        assert_eq!(eval(expr), expected);
    }
}
1657
/// `0o` / `0O` prefixed literals are read as octal.
#[test]
fn expr_octal_literals() {
    let cases: &[(&str, i128)] = &[("0o77", 63), ("0O10", 8)];
    for &(expr, expected) in cases {
        assert_eq!(eval(expr), expected);
    }
}
1663
/// A quoted character evaluates to its byte value.
#[test]
fn expr_char_literal() {
    let cases: &[(&str, i128)] = &[("'A'", 65), ("'0'", 48)];
    for &(expr, expected) in cases {
        assert_eq!(eval(expr), expected);
    }
}
1669
/// Addition, with and without surrounding whitespace.
#[test]
fn expr_addition() {
    let cases: &[(&str, i128)] = &[("1 + 2", 3), ("10+20+30", 60)];
    for &(expr, expected) in cases {
        assert_eq!(eval(expr), expected);
    }
}
1675
/// Subtraction is left-associative.
#[test]
fn expr_subtraction() {
    let cases: &[(&str, i128)] = &[("10 - 3", 7), ("100 - 50 - 25", 25)];
    for &(expr, expected) in cases {
        assert_eq!(eval(expr), expected);
    }
}
1681
/// Multiplication of two and three factors.
#[test]
fn expr_multiplication() {
    let cases: &[(&str, i128)] = &[("3 * 4", 12), ("2 * 3 * 5", 30)];
    for &(expr, expected) in cases {
        assert_eq!(eval(expr), expected);
    }
}
1687
/// Division is left-associative; dividing by zero yields 0.
#[test]
fn expr_division() {
    let cases: &[(&str, i128)] = &[
        ("12 / 4", 3),
        ("100 / 10 / 2", 5),
        // Division by zero → 0
        ("42 / 0", 0),
    ];
    for &(expr, expected) in cases {
        assert_eq!(eval(expr), expected);
    }
}
1695
/// Remainder; a zero divisor yields 0.
#[test]
fn expr_modulo() {
    let cases: &[(&str, i128)] = &[("10 % 3", 1), ("17 % 5", 2), ("42 % 0", 0)];
    for &(expr, expected) in cases {
        assert_eq!(eval(expr), expected);
    }
}
1702
/// `*` binds tighter than `+` and `-`, whichever side it appears on.
#[test]
fn expr_precedence_mul_over_add() {
    let cases: &[(&str, i128)] = &[("2 + 3 * 4", 14), ("3 * 4 + 2", 14), ("10 - 2 * 3", 4)];
    for &(expr, expected) in cases {
        assert_eq!(eval(expr), expected);
    }
}
1709
/// Parentheses override the default precedence.
#[test]
fn expr_parentheses() {
    let cases: &[(&str, i128)] = &[
        ("(2 + 3) * 4", 20),
        ("((1 + 2) * (3 + 4))", 21),
        ("(10)", 10),
    ];
    for &(expr, expected) in cases {
        assert_eq!(eval(expr), expected);
    }
}
1716
/// Parentheses may be nested several levels deep.
#[test]
fn expr_nested_parentheses() {
    let cases: &[(&str, i128)] = &[("((2 + 3) * (4 - 1))", 15), ("(((5)))", 5)];
    for &(expr, expected) in cases {
        assert_eq!(eval(expr), expected);
    }
}
1722
/// Bitwise AND on hex and binary operands.
#[test]
fn expr_bitwise_and() {
    let cases: &[(&str, i128)] = &[("0xFF & 0x0F", 0x0F), ("0b1010 & 0b1100", 0b1000)];
    for &(expr, expected) in cases {
        assert_eq!(eval(expr), expected);
    }
}
1728
/// Bitwise OR on hex and binary operands.
#[test]
fn expr_bitwise_or() {
    let cases: &[(&str, i128)] = &[("0x0F | 0xF0", 0xFF), ("0b1010 | 0b0101", 0b1111)];
    for &(expr, expected) in cases {
        assert_eq!(eval(expr), expected);
    }
}
1734
/// Bitwise XOR on hex and binary operands.
#[test]
fn expr_bitwise_xor() {
    let cases: &[(&str, i128)] = &[("0xFF ^ 0x0F", 0xF0), ("0b1010 ^ 0b1100", 0b0110)];
    for &(expr, expected) in cases {
        assert_eq!(eval(expr), expected);
    }
}
1740
/// Bitwise complement operates on the full `i128` width.
#[test]
fn expr_bitwise_not() {
    let cases: &[(&str, i128)] = &[
        // ~0 in i128 is all ones = -1
        ("~0", -1),
        ("~0xFF & 0xFF", 0),
    ];
    for &(expr, expected) in cases {
        assert_eq!(eval(expr), expected);
    }
}
1747
/// Left shift by a small constant.
#[test]
fn expr_shift_left() {
    let cases: &[(&str, i128)] = &[("1 << 8", 256), ("0xFF << 4", 0xFF0)];
    for &(expr, expected) in cases {
        assert_eq!(eval(expr), expected);
    }
}
1753
/// Right shift by a small constant.
#[test]
fn expr_shift_right() {
    let cases: &[(&str, i128)] = &[("256 >> 8", 1), ("0xFF0 >> 4", 0xFF)];
    for &(expr, expected) in cases {
        assert_eq!(eval(expr), expected);
    }
}
1759
/// Full truth table for `&&`.
#[test]
fn expr_logical_and() {
    let cases: &[(&str, i128)] = &[("1 && 1", 1), ("1 && 0", 0), ("0 && 1", 0), ("0 && 0", 0)];
    for &(expr, expected) in cases {
        assert_eq!(eval(expr), expected);
    }
}
1767
/// Full truth table for `||`.
#[test]
fn expr_logical_or() {
    let cases: &[(&str, i128)] = &[("1 || 1", 1), ("1 || 0", 1), ("0 || 1", 1), ("0 || 0", 0)];
    for &(expr, expected) in cases {
        assert_eq!(eval(expr), expected);
    }
}
1775
/// `!` maps zero to 1 and any non-zero value to 0.
#[test]
fn expr_logical_not() {
    let cases: &[(&str, i128)] = &[("!0", 1), ("!1", 0), ("!42", 0)];
    for &(expr, expected) in cases {
        assert_eq!(eval(expr), expected);
    }
}
1782
/// `==` and `!=` produce 1 or 0.
#[test]
fn expr_equality() {
    let cases: &[(&str, i128)] = &[("5 == 5", 1), ("5 == 6", 0), ("5 != 6", 1), ("5 != 5", 0)];
    for &(expr, expected) in cases {
        assert_eq!(eval(expr), expected);
    }
}
1790
/// `<`, `>`, `<=` and `>=` produce 1 or 0, including the equal-operands cases.
#[test]
fn expr_relational() {
    let cases: &[(&str, i128)] = &[
        ("3 < 5", 1),
        ("5 < 3", 0),
        ("5 > 3", 1),
        ("3 > 5", 0),
        ("5 <= 5", 1),
        ("5 <= 6", 1),
        ("6 <= 5", 0),
        ("5 >= 5", 1),
        ("6 >= 5", 1),
        ("5 >= 6", 0),
    ];
    for &(expr, expected) in cases {
        assert_eq!(eval(expr), expected);
    }
}
1804
/// Unary minus, alone and directly after a binary operator.
#[test]
fn expr_unary_minus() {
    let cases: &[(&str, i128)] = &[("-1", -1), ("-(-5)", 5), ("3 + -2", 1), ("3 - -2", 5)];
    for &(expr, expected) in cases {
        assert_eq!(eval(expr), expected);
    }
}
1812
/// Operators from different precedence levels combine as they do in C.
#[test]
fn expr_mixed_precedence() {
    // `+` binds tighter than `<<` (parse_shift takes its operands from
    // parse_additive), so this is (1 + 2) << 3 = 24, not 1 + (2 << 3) = 17.
    assert_eq!(eval("1 + 2 << 3"), 24);

    // Comparison: == is lower than +
    assert_eq!(eval("2 + 3 == 5"), 1);
    assert_eq!(eval("2 + 3 == 6"), 0);

    // Logical: && is lower than ==
    assert_eq!(eval("1 == 1 && 2 == 2"), 1);
    assert_eq!(eval("1 == 1 && 2 == 3"), 0);

    // || is the lowest
    assert_eq!(eval("0 && 1 || 1"), 1);
    assert_eq!(eval("1 || 0 && 0"), 1);
}
1840
1841    #[test]
1842    fn expr_complex_bitwise() {
1843        // Page-align: addr & ~0xFFF
1844        // Can't test with large addresses easily, but logic works
1845        assert_eq!(eval("0x1234 & ~0xFFF & 0xFFFF"), 0x1000);
1846        // Flag test
1847        assert_eq!(eval("(0x03 & 0x01) != 0"), 1);
1848        assert_eq!(eval("(0x02 & 0x01) != 0"), 0);
1849    }
1850
1851    #[test]
1852    fn expr_symbols() {
1853        let mut syms = BTreeMap::new();
1854        syms.insert(String::from("X"), 10);
1855        syms.insert(String::from("Y"), 20);
1856        assert_eq!(eval_with("X + Y", &syms), 30);
1857        assert_eq!(eval_with("X * Y", &syms), 200);
1858        assert_eq!(eval_with("(X + Y) * 2", &syms), 60);
1859    }
1860
1861    #[test]
1862    fn expr_defined_function() {
1863        let mut syms = BTreeMap::new();
1864        syms.insert(String::from("FOO"), 1);
1865        assert_eq!(eval_with("defined(FOO)", &syms), 1);
1866        assert_eq!(eval_with("defined(BAR)", &syms), 0);
1867        assert_eq!(eval_with("defined(FOO) && defined(BAR)", &syms), 0);
1868        assert_eq!(eval_with("defined(FOO) || defined(BAR)", &syms), 1);
1869    }
1870
1871    #[test]
1872    fn expr_regression_0x_minus() {
1873        // Old parser choked on "0x10 - 1" because rsplit_once('-') split at "0x10"
1874        assert_eq!(eval("0x10 - 1"), 15);
1875        assert_eq!(eval("0xFF - 0xF0"), 15);
1876    }
1877
1878    #[test]
1879    fn expr_whitespace_tolerance() {
1880        assert_eq!(eval("  42  "), 42);
1881        assert_eq!(eval("  1  +  2  "), 3);
1882        assert_eq!(eval(" ( 1 + 2 ) * 3 "), 9);
1883    }
1884
1885    #[test]
1886    fn expr_empty() {
1887        assert_eq!(eval(""), 0);
1888        assert_eq!(eval("   "), 0);
1889    }
1890
1891    #[test]
1892    fn expr_if_mul_integrated() {
1893        // This was the data-corruption bug: `.if 2 * 3 == 6` silently failed
1894        let mut pp = Preprocessor::new();
1895        let source = "\
1896.if 2 * 3 == 6
1897    nop
1898.endif
1899";
1900        let result = pp.process(source).unwrap();
1901        assert!(result.contains("nop"), "2*3==6 should be true");
1902    }
1903
1904    #[test]
1905    fn expr_if_parenthesised_integrated() {
1906        let mut pp = Preprocessor::new();
1907        let source = "\
1908.if (1 + 2) * 4 == 12
1909    mov eax, 1
1910.endif
1911";
1912        let result = pp.process(source).unwrap();
1913        assert!(result.contains("mov eax, 1"));
1914    }
1915
1916    #[test]
1917    fn expr_if_shift_integrated() {
1918        let mut pp = Preprocessor::new();
1919        let source = "\
1920.if 1 << 4 == 16
1921    nop
1922.endif
1923";
1924        let result = pp.process(source).unwrap();
1925        assert!(result.contains("nop"), "1<<4 should equal 16");
1926    }
1927
1928    #[test]
1929    fn expr_if_bitwise_and_integrated() {
1930        let mut pp = Preprocessor::new();
1931        let source = "\
1932.equ FLAGS, 0x07
1933.if FLAGS & 0x02
1934    nop
1935.endif
1936";
1937        let result = pp.process(source).unwrap();
1938        assert!(result.contains("nop"), "0x07 & 0x02 should be non-zero");
1939    }
1940
1941    #[test]
1942    fn expr_if_logical_and_integrated() {
1943        let mut pp = Preprocessor::new();
1944        pp.define_symbol("A", 1);
1945        pp.define_symbol("B", 1);
1946        let source = "\
1947.if defined(A) && defined(B)
1948    nop
1949.endif
1950";
1951        let result = pp.process(source).unwrap();
1952        assert!(result.contains("nop"), "both A and B defined");
1953    }
1954
1955    #[test]
1956    fn expr_if_logical_or_integrated() {
1957        let mut pp = Preprocessor::new();
1958        pp.define_symbol("A", 1);
1959        let source = "\
1960.if defined(A) || defined(B)
1961    nop
1962.endif
1963";
1964        let result = pp.process(source).unwrap();
1965        assert!(result.contains("nop"), "A is defined so OR should be true");
1966    }
1967
1968    #[test]
1969    fn expr_elseif_with_operators() {
1970        let mut pp = Preprocessor::new();
1971        pp.define_symbol("MODE", 2);
1972        let source = "\
1973.if MODE * 2 == 2
1974    wrong
1975.elseif MODE * 2 == 4
1976    correct
1977.else
1978    also_wrong
1979.endif
1980";
1981        let result = pp.process(source).unwrap();
1982        assert!(!result.contains("wrong"));
1983        assert!(result.contains("correct"));
1984    }
1985}