Skip to main content

rlx_cpu/
asm_check.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! FileCheck-style disassembly regression tests (plan #10).
17//!
18//! Borrowed from MAX's `bazel/internal/mojo_filecheck_test.bzl` +
19//! `bazel/internal/lit.bzl` pattern: assertions on emitted IR/asm
20//! catch optimizer regressions that unit tests miss. The classic
21//! one is "the optimizer changed and now we lost a SIMD intrinsic"
22//! — a benchmark would notice the slowdown later, but FileCheck
23//! catches it at PR time.
24//!
//! The Rust spelling shells out to `llvm-objdump` (falling back to
//! plain `objdump`) on the running test binary, locates a named
//! function in the disassembly, and asserts each requested pattern
//! appears.
28//!
29//! Tests that use this should mark themselves `#[ignore]` so
30//! they're opt-in via `cargo test -- --ignored asm` — `objdump`
31//! isn't always available in CI, and disassembling a debug binary
32//! is slow.
33//!
//! Pattern matching is plain substring search via the standard
//! library only — no extern dep on `regex`.
36
37use std::path::PathBuf;
38use std::process::Command;
39
/// Finds a usable disassembler on `PATH`.
///
/// Tries `llvm-objdump` first and then plain `objdump`, running each
/// with `--version`; the first one that both spawns and exits
/// successfully wins. Returns the bare command name (not a resolved
/// absolute path), or `None` when neither candidate could be run.
fn locate_objdump() -> Option<PathBuf> {
    // Order matters: the LLVM tool is preferred over the generic name.
    const CANDIDATES: [&str; 2] = ["llvm-objdump", "objdump"];

    CANDIDATES
        .iter()
        .copied()
        .find(|name| {
            // A spawn failure (tool absent) and a non-zero exit are
            // both treated as "not usable".
            match Command::new(name).arg("--version").output() {
                Ok(out) => out.status.success(),
                Err(_) => false,
            }
        })
        .map(PathBuf::from)
}
52
/// Path of the executable that is currently running.
///
/// This is a thin wrapper over [`std::env::current_exe`]; when it is
/// called from a test, that executable is the test binary itself.
/// Any I/O error from the lookup is passed through unchanged.
fn current_test_binary() -> std::io::Result<PathBuf> {
    let exe = std::env::current_exe()?;
    Ok(exe)
}
60
61/// Disassemble the test binary; return the full text dump.
62/// Errors out (with a skippable label) when the tool isn't
63/// available so callers can short-circuit gracefully.
64pub fn disassemble_self() -> Result<String, AsmCheckError> {
65    let tool = locate_objdump().ok_or(AsmCheckError::ToolMissing)?;
66    let bin = current_test_binary().map_err(|e| AsmCheckError::IoError(e.to_string()))?;
67    let out = Command::new(&tool)
68        .arg("-d")
69        .arg("--no-show-raw-insn")
70        .arg(&bin)
71        .output()
72        .map_err(|e| AsmCheckError::IoError(e.to_string()))?;
73    if !out.status.success() {
74        return Err(AsmCheckError::ToolFailed {
75            stderr: String::from_utf8_lossy(&out.stderr).into_owned(),
76        });
77    }
78    Ok(String::from_utf8_lossy(&out.stdout).into_owned())
79}
80
/// Returns `true` when `line` looks like a header line in objdump
/// output rather than an instruction line.
///
/// Headers (`0000000000001139 <name>:`, `_name:`, `Disassembly of
/// section .text:`) start in column 0 and end with `:`. Instruction
/// lines are either indented or carry an instruction after the
/// address, so they never satisfy both conditions.
fn is_disasm_header(line: &str) -> bool {
    line.trim_end().ends_with(':') && !line.starts_with(char::is_whitespace)
}

/// Look up a function by symbol-name substring and return the slice
/// of disassembly belonging to it.
///
/// The section starts at the first line that contains `name_substr`
/// and ends with `:` (ignoring trailing whitespace). It runs up to,
/// but not including, the next header line, or to the end of `disasm`
/// when no further header follows. The header line itself is part of
/// the returned slice.
///
/// Returns `None` when no header line contains `name_substr`.
///
/// Byte offsets are tracked on the raw text, so both `\n` and `\r\n`
/// line endings yield a correctly aligned slice.
pub fn function_section<'a>(disasm: &'a str, name_substr: &str) -> Option<&'a str> {
    let mut section_start: Option<usize> = None;
    let mut offset = 0usize;

    for raw in disasm.split_inclusive('\n') {
        let line_start = offset;
        offset += raw.len();

        // Drop only the line terminator so matching sees the same text
        // that `str::lines` would have produced.
        let line = raw.strip_suffix('\n').unwrap_or(raw);
        let line = line.strip_suffix('\r').unwrap_or(line);

        match section_start {
            None => {
                if line.contains(name_substr) && line.trim_end().ends_with(':') {
                    section_start = Some(line_start);
                }
            }
            Some(start) => {
                // Function headers are normally `ADDR <sym>:` with a
                // zero-padded address, so a leading '0' must not stop
                // them from being recognised as the boundary.
                if is_disasm_header(line) {
                    return Some(&disasm[start..line_start]);
                }
            }
        }
    }

    // The matched function was the last one in the dump.
    section_start.map(|start| &disasm[start..])
}
112
113/// Assert each substring in `expected` appears at least once in
114/// the function `name`'s disassembly. Returns Err if the
115/// disassembler isn't present so callers can `eprintln!` skip
116/// without failing.
117pub fn assert_function_contains(name_substr: &str, expected: &[&str]) -> Result<(), AsmCheckError> {
118    let disasm = disassemble_self()?;
119    let body = function_section(&disasm, name_substr).ok_or(AsmCheckError::FunctionNotFound {
120        name: name_substr.into(),
121    })?;
122    for pat in expected {
123        if !body.contains(pat) {
124            return Err(AsmCheckError::PatternMissing {
125                function: name_substr.into(),
126                pattern: (*pat).into(),
127            });
128        }
129    }
130    Ok(())
131}
132
/// Failure modes of the disassembly-check helpers in this file.
#[derive(Debug)]
pub enum AsmCheckError {
    /// No usable disassembler was found; `disassemble_self` returns
    /// this when `locate_objdump` yields `None`. The tests in this
    /// file treat it as "skip", not as a failure.
    ToolMissing,
    /// The disassembler ran but exited unsuccessfully; `stderr` holds
    /// its (lossily decoded) standard error output.
    ToolFailed { stderr: String },
    /// An I/O error, stored as its message string, from looking up
    /// the current executable or from spawning the disassembler.
    IoError(String),
    /// `function_section` found no header line containing `name`.
    FunctionNotFound { name: String },
    /// The section for `function` was found, but `pattern` does not
    /// occur in it.
    PatternMissing { function: String, pattern: String },
}
141
142impl std::fmt::Display for AsmCheckError {
143    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
144        match self {
145            Self::ToolMissing => write!(f, "neither llvm-objdump nor objdump found in PATH"),
146            Self::ToolFailed { stderr } => write!(f, "objdump failed: {stderr}"),
147            Self::IoError(s) => write!(f, "io error: {s}"),
148            Self::FunctionNotFound { name } => {
149                write!(f, "function matching `{name}` not found in disassembly")
150            }
151            Self::PatternMissing { function, pattern } => write!(
152                f,
153                "function `{function}` is missing expected pattern `{pattern}`"
154            ),
155        }
156    }
157}
158
// Marker impl: `Debug` is derived and `Display` is implemented by
// hand, so the trait's default methods are all that is needed.
impl std::error::Error for AsmCheckError {}
160
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: the running binary can be disassembled and the
    /// dump is not trivially short. `#[ignore]`d so it only runs on
    /// request; a missing disassembler is logged and skipped rather
    /// than treated as a failure.
    #[test]
    #[ignore]
    fn disassemble_self_succeeds() {
        let dump = match disassemble_self() {
            Ok(text) => text,
            Err(AsmCheckError::ToolMissing) => {
                eprintln!("[asm-check] skipping: objdump not in PATH");
                return;
            }
            Err(other) => panic!("disassembly failed: {other}"),
        };
        assert!(dump.len() > 1024, "disassembly suspiciously small");
    }

    /// Regression check: on aarch64, the function whose symbol
    /// contains `Cumsum` must still contain an `fadd` instruction.
    ///
    /// A missing tool or a missing symbol is logged and skipped, so
    /// this test does not fail when the kernel's symbol is absent
    /// from the binary.
    #[test]
    #[ignore]
    fn cumsum_kernel_keeps_simd_on_aarch64() {
        // The expected mnemonic is aarch64-specific; nothing to check
        // on other targets.
        if cfg!(not(target_arch = "aarch64")) {
            eprintln!("[asm-check] skipping: not aarch64");
            return;
        }

        // `fadd` is matched as a plain substring of the function's
        // disassembly.
        let outcome = assert_function_contains("Cumsum", &["fadd"]);
        if let Err(err) = outcome {
            match err {
                AsmCheckError::ToolMissing | AsmCheckError::FunctionNotFound { .. } => {
                    eprintln!("[asm-check] skipping: tool or symbol missing");
                }
                other => panic!("Cumsum kernel asm check failed: {other}"),
            }
        }
    }
}