use crate::errors::QalaError;
use crate::span::Span;
use crate::typed_ast::{TypedBlock, TypedElseBranch, TypedExpr, TypedStmt};
use super::Arm64Backend;
/// Branch targets of the innermost enclosing loop, pushed onto the backend's
/// loop stack while the loop body is being lowered.
///
/// `break` branches to `end`; `continue` branches to `continue_target`.
pub(crate) struct LoopLabels {
    /// Label emitted immediately after the loop; the target of `break`.
    pub(crate) end: String,
    /// Target of `continue`: the condition test of a `while` loop, or the
    /// increment step of a `for` loop.
    pub(crate) continue_target: String,
}
impl Arm64Backend {
pub(super) fn compile_block(&mut self, block: &TypedBlock) -> Result<(), QalaError> {
self.scopes.push(Vec::new());
for stmt in &block.stmts {
if let Err(e) = self.compile_stmt(stmt) {
self.scopes.pop();
return Err(e);
}
}
if let Some(value) = &block.value
&& let Err(e) = self.compile_expr(value)
{
self.scopes.pop();
return Err(e);
}
self.scopes.pop();
Ok(())
}
pub(super) fn compile_stmt(&mut self, stmt: &TypedStmt) -> Result<(), QalaError> {
match stmt {
TypedStmt::Let {
name, init, span, ..
} => self.compile_let(name, init, *span),
TypedStmt::If {
cond,
then_block,
else_branch,
..
} => self.compile_if(cond, then_block, else_branch.as_ref()),
TypedStmt::While { cond, body, .. } => self.compile_while(cond, body),
TypedStmt::For {
var,
iter,
body,
span,
..
} => self.compile_for(var, iter, body, *span),
TypedStmt::Return { value, span } => {
if let Some(e) = value {
self.compile_expr(e)?;
}
let fn_name = self.current_fn.clone().ok_or_else(|| QalaError::Type {
span: *span,
message: "arm64 backend: return outside a function".to_string(),
})?;
let label = self.epilogue_label(&fn_name);
self.asm.emit_insn(&format!("b {label}"));
Ok(())
}
TypedStmt::Break { span } => {
let target = self
.loops
.last()
.map(|l| l.end.clone())
.ok_or_else(|| self.unsupported_stmt(*span, "break outside a loop"))?;
self.asm.emit_insn(&format!("b {target}"));
Ok(())
}
TypedStmt::Continue { span } => {
let target = self
.loops
.last()
.map(|l| l.continue_target.clone())
.ok_or_else(|| self.unsupported_stmt(*span, "continue outside a loop"))?;
self.asm.emit_insn(&format!("b {target}"));
Ok(())
}
TypedStmt::Expr { expr, .. } => self.compile_expr(expr),
TypedStmt::Defer { span, .. } => {
Err(self.unsupported_stmt(*span, "the defer statement"))
}
}
}
fn compile_let(&mut self, name: &str, init: &TypedExpr, span: Span) -> Result<(), QalaError> {
self.compile_expr(init)?;
let slot = self.next_binding_slot(span)?;
self.asm
.emit_insn_commented(&format!("str x0, [fp, {slot}]"), name);
self.bind_name(name, slot);
Ok(())
}
fn compile_if(
&mut self,
cond: &TypedExpr,
then_block: &TypedBlock,
else_branch: Option<&TypedElseBranch>,
) -> Result<(), QalaError> {
let end = self.labels.fresh("if_end");
match else_branch {
Some(branch) => {
let else_label = self.labels.fresh("if_else");
self.compile_expr(cond)?;
self.asm.emit_insn(&format!("cbz x0, {else_label}"));
self.compile_block(then_block)?;
self.asm.emit_insn(&format!("b {end}"));
self.asm.emit_label(&else_label);
match branch {
TypedElseBranch::Block(b) => self.compile_block(b)?,
TypedElseBranch::If(boxed) => self.compile_stmt(boxed)?,
}
self.asm.emit_label(&end);
}
None => {
self.compile_expr(cond)?;
self.asm.emit_insn(&format!("cbz x0, {end}"));
self.compile_block(then_block)?;
self.asm.emit_label(&end);
}
}
Ok(())
}
fn compile_while(&mut self, cond: &TypedExpr, body: &TypedBlock) -> Result<(), QalaError> {
let body_label = self.labels.fresh("while_body");
let test_label = self.labels.fresh("while_test");
let end_label = self.labels.fresh("while_end");
self.asm.emit_insn(&format!("b {test_label}"));
self.asm.emit_label(&body_label);
self.loops.push(LoopLabels {
continue_target: test_label.clone(),
end: end_label.clone(),
});
let body_result = self.compile_block(body);
self.loops.pop();
body_result?;
self.asm.emit_label(&test_label);
self.compile_expr(cond)?;
self.asm.emit_insn(&format!("cbnz x0, {body_label}"));
self.asm.emit_label(&end_label);
Ok(())
}
fn compile_for(
&mut self,
var: &str,
iter: &TypedExpr,
body: &TypedBlock,
span: Span,
) -> Result<(), QalaError> {
let (start, end, inclusive) = match iter {
TypedExpr::Range {
start,
end,
inclusive,
..
} => (start, end, *inclusive),
_ => {
return Err(self.unsupported_stmt(span, "for over a non-range iterable"));
}
};
let start = start
.as_ref()
.ok_or_else(|| self.unsupported_stmt(span, "a for range with no start bound"))?;
let end = end
.as_ref()
.ok_or_else(|| self.unsupported_stmt(span, "a for range with no end bound"))?;
let var_slot = self.next_binding_slot(span)?;
let end_slot = self.next_binding_slot(span)?;
self.compile_expr(start)?;
self.asm
.emit_insn_commented(&format!("str x0, [fp, {var_slot}]"), var);
self.compile_expr(end)?;
self.asm
.emit_insn_commented(&format!("str x0, [fp, {end_slot}]"), "for end");
let body_label = self.labels.fresh("for_body");
let incr_label = self.labels.fresh("for_incr");
let test_label = self.labels.fresh("for_test");
let end_label = self.labels.fresh("for_end");
self.bind_name(var, var_slot);
self.asm.emit_insn(&format!("b {test_label}"));
self.asm.emit_label(&body_label);
self.loops.push(LoopLabels {
continue_target: incr_label.clone(),
end: end_label.clone(),
});
let body_result = self.compile_block(body);
self.loops.pop();
body_result?;
self.asm.emit_label(&incr_label);
self.asm.emit_insn(&format!("ldr x0, [fp, {var_slot}]"));
self.asm.emit_insn("add x0, x0, 1");
self.asm.emit_insn(&format!("str x0, [fp, {var_slot}]"));
self.asm.emit_label(&test_label);
self.asm.emit_insn(&format!("ldr x0, [fp, {var_slot}]"));
self.asm.emit_insn(&format!("ldr x9, [fp, {end_slot}]"));
self.asm.emit_insn("cmp x0, x9");
let branch = if inclusive { "b.le" } else { "b.lt" };
self.asm.emit_insn(&format!("{branch} {body_label}"));
self.asm.emit_label(&end_label);
Ok(())
}
fn next_binding_slot(&mut self, span: Span) -> Result<i64, QalaError> {
let index = self.binding_cursor;
let slot = self
.frame()
.binding_slot(index)
.ok_or_else(|| QalaError::Type {
span,
message: format!("arm64 backend: binding occurrence {index} has no slot"),
})?;
self.binding_cursor += 1;
Ok(slot)
}
fn bind_name(&mut self, name: &str, slot: i64) {
if let Some(scope) = self.scopes.last_mut() {
scope.push((name.to_string(), slot));
}
}
pub(super) fn resolve_name(&self, name: &str) -> Option<i64> {
for scope in self.scopes.iter().rev() {
for (n, slot) in scope.iter().rev() {
if n == name {
return Some(*slot);
}
}
}
self.frame().slot_of(name)
}
pub(super) fn unsupported_stmt(&self, span: Span, what: &str) -> QalaError {
QalaError::Type {
span,
message: format!("the arm64 backend does not yet support {what}"),
}
}
}
// Tests for statement lowering. The snapshot tests compare the complete emitted
// assembly against files under `tests/snapshots/`. The remaining tests search
// the output for the specific instruction and label shapes that each statement
// kind should produce.
#[cfg(test)]
mod tests {
use super::*;
use crate::lexer::Lexer;
use crate::parser::Parser;
use crate::typechecker::check_program;
use crate::typed_ast::TypedAst;
// (snapshot file name, Qala source) pairs. Each program is compiled and the
// output compared byte-for-byte with `tests/snapshots/<name>`. A trailing `\`
// in a Rust string skips the newline and any leading whitespace on the next
// line, so `\x20` puts back the one-space indent inside each program.
const SNAPSHOT_PROGRAMS: [(&str, &str); 4] = [
(
"arm64_let.s",
"fn let_demo(a: i64) -> i64 {\n\
\x20 let x = a + 1\n\
\x20 let y = x * 2\n\
\x20 let x = y - a\n\
\x20 x\n\
}\n",
),
(
"arm64_if_else.s",
"fn sign(n: i64) -> i64 {\n\
\x20 if n == 0 { return 0 }\n\
\x20 if n > 0 { return 1 } else { return -1 }\n\
\x20 0\n\
}\n\
fn grade(s: i64) -> i64 {\n\
\x20 if s >= 90 { return 4 }\n\
\x20 else if s >= 80 { return 3 }\n\
\x20 else { return 2 }\n\
\x20 0\n\
}\n",
),
(
"arm64_while.s",
"fn loop_demo(c: bool) {\n\
\x20 while c {\n\
\x20 if c { break }\n\
\x20 continue\n\
\x20 }\n\
}\n",
),
(
"arm64_for.s",
"fn for_demo(n: i64) {\n\
\x20 for i in 0..n { }\n\
\x20 for j in 0..=n { }\n\
}\n",
),
];
// Reads a snapshot file relative to the crate root. CRLF is normalised to LF
// so that a checkout with Windows line endings still compares equal.
fn read_snapshot(name: &str) -> String {
let path = format!("{}/tests/snapshots/{name}", env!("CARGO_MANIFEST_DIR"));
std::fs::read_to_string(&path)
.unwrap_or_else(|e| panic!("read {path}: {e}"))
.replace("\r\n", "\n")
}
// Lexes, parses and typechecks `src`, and asserts that it has no type errors.
fn typecheck(src: &str) -> TypedAst {
let tokens = Lexer::tokenize(src).expect("lex failed");
let ast = Parser::parse(&tokens).expect("parse failed");
let (typed, terrors, _) = check_program(&ast, src);
assert!(terrors.is_empty(), "typecheck errors: {terrors:?}");
typed
}
// Typechecks `src`, compiles it through the module-level `compile_arm64`
// entry point, and panics on any backend error.
fn compile_ok(src: &str) -> String {
let typed = typecheck(src);
super::super::compile_arm64(&typed, src).unwrap_or_else(|e| panic!("arm64 errors: {e:?}"))
}
// Returns the source program registered in SNAPSHOT_PROGRAMS for a snapshot file name.
fn snapshot_src(name: &str) -> &'static str {
SNAPSHOT_PROGRAMS
.iter()
.find(|(n, _)| *n == name)
.map(|(_, src)| *src)
.unwrap_or_else(|| panic!("no snapshot program named {name}"))
}
#[test]
fn let_bindings_match_the_snapshot() {
let emitted = compile_ok(snapshot_src("arm64_let.s"));
assert_eq!(
emitted,
read_snapshot("arm64_let.s"),
"arm64 let emission drifted from snapshot"
);
}
#[test]
fn if_else_matches_the_snapshot() {
let emitted = compile_ok(snapshot_src("arm64_if_else.s"));
assert_eq!(
emitted,
read_snapshot("arm64_if_else.s"),
"arm64 if/else emission drifted from snapshot"
);
}
#[test]
fn while_loops_match_the_snapshot() {
let emitted = compile_ok(snapshot_src("arm64_while.s"));
assert_eq!(
emitted,
read_snapshot("arm64_while.s"),
"arm64 while emission drifted from snapshot"
);
}
#[test]
fn for_loops_match_the_snapshot() {
let emitted = compile_ok(snapshot_src("arm64_for.s"));
assert_eq!(
emitted,
read_snapshot("arm64_for.s"),
"arm64 for emission drifted from snapshot"
);
}
#[test]
fn a_let_stores_x0_into_its_stack_slot() {
let out = compile_ok("fn f() -> i64 { let x = 7\nx }");
assert!(out.contains("mov x0, 7"), "{out}");
assert!(
out.contains("str x0, [fp,"),
"missing the let store: {out}"
);
}
// The `[fp, N` operand of the first store must match the operand of the
// first load: the ident must read back the slot its `let` wrote.
#[test]
fn an_ident_use_of_a_let_loads_its_slot() {
let out = compile_ok("fn f() -> i64 { let x = 7\nx }");
let store = out
.lines()
.find(|l| l.contains("str x0, [fp,"))
.expect("no store");
let load = out
.lines()
.find(|l| l.contains("ldr x0, [fp,"))
.expect("no load");
let store_slot = &store[store.find("[fp,").unwrap()..store.find(']').unwrap()];
let load_slot = &load[load.find("[fp,").unwrap()..load.find(']').unwrap()];
assert_eq!(
store_slot, load_slot,
"let store and ident load slots differ"
);
}
// Two `let x` in the same block: the final `x` must load the second store's
// slot (the last load in the output is compared with the second store).
#[test]
fn a_shadowing_let_resolves_the_ident_to_the_newest_slot() {
let out = compile_ok("fn f() -> i64 { let x = 1\nlet x = 2\nx }");
let stores: Vec<&str> = out
.lines()
.filter(|l| l.contains("str x0, [fp,"))
.collect();
assert_eq!(stores.len(), 2, "two lets -> two stores: {out}");
let second_store = stores[1];
let second_slot =
&second_store[second_store.find("[fp,").unwrap()..second_store.find(']').unwrap()];
let load = out
.lines()
.rev()
.find(|l| l.contains("ldr x0, [fp,"))
.expect("no load");
let load_slot = &load[load.find("[fp,").unwrap()..load.find(']').unwrap()];
assert_eq!(
load_slot, second_slot,
"ident must resolve to the shadowing let"
);
}
#[test]
fn an_if_without_else_emits_one_cbz_and_an_end_label() {
let out = compile_ok("fn f(c: bool) -> i64 { if c { return 1 }\n0 }");
assert!(out.contains("cbz x0, .Lif_end_"), "missing cbz: {out}");
assert!(out.contains(".Lif_end_"), "missing end label: {out}");
assert!(!out.contains(".Lif_else_"), "unexpected else label: {out}");
}
#[test]
fn an_if_else_emits_both_the_else_and_end_labels() {
let out = compile_ok("fn f(c: bool) -> i64 { if c { return 1 } else { return 2 }\n0 }");
assert!(
out.contains("cbz x0, .Lif_else_"),
"missing cbz to else: {out}"
);
assert!(out.contains(".Lif_else_"), "missing else label: {out}");
assert!(out.contains(".Lif_end_"), "missing end label: {out}");
assert!(
out.contains("b .Lif_end_"),
"missing jump over else: {out}"
);
}
// `else if` is lowered as a nested `if` statement, so the chain defines two
// distinct else labels (counted as lines that begin with the label).
#[test]
fn an_else_if_chain_emits_nested_labels() {
let out = compile_ok(
"fn f(a: bool, b: bool) -> i64 { \
if a { return 1 } else if b { return 2 } else { return 3 }\n0 }",
);
let else_labels: Vec<&str> = out
.lines()
.filter(|l| l.trim_start().starts_with(".Lif_else_"))
.collect();
assert_eq!(
else_labels.len(),
2,
"an else-if chain has two else labels: {out}"
);
}
#[test]
fn a_return_in_a_block_branches_to_the_epilogue() {
let out = compile_ok("fn f(c: bool) -> i64 { if c { return 9 }\n0 }");
assert!(out.contains("b .Lf_epilogue"), "{out}");
}
#[test]
fn a_let_inside_an_if_gets_its_own_slot() {
let out = compile_ok("fn f(c: bool) -> i64 { let a = 1\nif c { let b = 2\nreturn b }\na }");
let stores: Vec<&str> = out
.lines()
.filter(|l| l.contains("str x0, [fp,"))
.collect();
assert!(stores.len() >= 2, "outer and inner let both store: {out}");
}
// Calls compile_stmt directly, on a hand-built Defer node and a fresh backend.
#[test]
fn defer_is_rejected_cleanly() {
let stmt = TypedStmt::Defer {
expr: TypedExpr::Int {
value: 1,
ty: crate::types::QalaType::I64,
span: Span::new(0, 1),
},
span: Span::new(0, 5),
};
let mut backend = Arm64Backend::new("");
let err = backend
.compile_stmt(&stmt)
.expect_err("a defer must be rejected");
match err {
QalaError::Type { message, .. } => {
assert!(message.contains("defer"), "message: {message}");
}
other => panic!("expected QalaError::Type, got {other:?}"),
}
}
// A fresh backend has an empty loop stack, so `break` must produce an error, not a panic.
#[test]
fn a_break_outside_a_loop_is_a_clean_error_not_a_panic() {
let stmt = TypedStmt::Break {
span: Span::new(0, 1),
};
let mut backend = Arm64Backend::new("");
let err = backend
.compile_stmt(&stmt)
.expect_err("a break outside a loop must be rejected");
match err {
QalaError::Type { message, .. } => {
assert!(message.contains("break"), "message: {message}");
}
other => panic!("expected QalaError::Type, got {other:?}"),
}
}
// A while loop jumps to the test first, the test sits after the body, and
// `cbnz` loops back to the body.
#[test]
fn a_while_emits_the_test_at_bottom_shape() {
let out = compile_ok("fn f(c: bool) { while c { } }");
assert!(
out.contains("b .Lwhile_test_"),
"missing jump to test: {out}"
);
assert!(
out.lines()
.any(|l| l.trim_start().starts_with(".Lwhile_body_")),
"missing body label: {out}"
);
assert!(
out.lines()
.any(|l| l.trim_start().starts_with(".Lwhile_test_")),
"missing test label: {out}"
);
assert!(
out.contains("cbnz x0, .Lwhile_body_"),
"missing cbnz: {out}"
);
assert!(
out.lines()
.any(|l| l.trim_start().starts_with(".Lwhile_end_")),
"missing end label: {out}"
);
}
// Takes the end-label name from its definition line (trailing ':' removed)
// and checks that `break` branches to exactly that label.
#[test]
fn a_while_break_branches_to_the_end_label() {
let out = compile_ok("fn f(c: bool) { while c { break } }");
let end_target = out
.lines()
.find_map(|l| {
let t = l.trim();
t.strip_prefix(".Lwhile_end_")
.map(|n| format!(".Lwhile_end_{}", n.trim_end_matches(':')))
})
.expect("no while end label");
assert!(
out.contains(&format!("b {end_target}")),
"break must branch to {end_target}: {out}"
);
}
// Exactly two branches to the test label are expected: the loop-entry
// `b test` emitted before the body, and the one from `continue`.
#[test]
fn a_while_continue_branches_to_the_test_label() {
let out = compile_ok("fn f(c: bool) { while c { continue } }");
let test_target = out
.lines()
.find_map(|l| {
let t = l.trim();
t.strip_prefix(".Lwhile_test_")
.map(|n| format!(".Lwhile_test_{}", n.trim_end_matches(':')))
})
.expect("no while test label");
let branches = out
.lines()
.filter(|l| l.contains(&format!("b {test_target}")))
.count();
assert_eq!(
branches, 2,
"loop entry + continue both branch to the test: {out}"
);
}
#[test]
fn a_for_over_an_exclusive_range_emits_b_lt() {
let out = compile_ok("fn f(n: i64) { for i in 0..n { } }");
assert!(
out.contains("b .Lfor_test_"),
"missing jump to test: {out}"
);
assert!(
out.lines()
.any(|l| l.trim_start().starts_with(".Lfor_incr_")),
"missing increment label: {out}"
);
assert!(
out.contains("add x0, x0, 1"),
"missing the increment: {out}"
);
assert!(out.contains("cmp x0, x9"), "missing the compare: {out}");
assert!(
out.contains("b.lt .Lfor_body_"),
"exclusive range -> b.lt: {out}"
);
assert!(
!out.contains("b.le"),
"an exclusive range must not use b.le: {out}"
);
}
#[test]
fn a_for_over_an_inclusive_range_emits_b_le() {
let out = compile_ok("fn f(n: i64) { for i in 0..=n { } }");
assert!(
out.contains("b.le .Lfor_body_"),
"inclusive range -> b.le: {out}"
);
assert!(
!out.contains("b.lt"),
"an inclusive range must not use b.lt: {out}"
);
}
// In a for loop, `continue` must go through the increment step, not straight to the test.
#[test]
fn a_for_continue_branches_to_the_increment_label() {
let out = compile_ok("fn f(n: i64) { for i in 0..n { continue } }");
let incr_target = out
.lines()
.find_map(|l| {
let t = l.trim();
t.strip_prefix(".Lfor_incr_")
.map(|n| format!(".Lfor_incr_{}", n.trim_end_matches(':')))
})
.expect("no for increment label");
assert!(
out.contains(&format!("b {incr_target}")),
"continue must branch to {incr_target}: {out}"
);
}
// The end bound is stored once, into the slot annotated `// for end`; the
// loop test reloads it instead of re-evaluating the expression.
#[test]
fn a_for_evaluates_its_end_bound_once_into_a_slot() {
let out = compile_ok("fn f(n: i64) { for i in 0..n { } }");
let end_stores = out
.lines()
.filter(|l| l.contains("str x0, [fp,") && l.contains("// for end"))
.count();
assert_eq!(
end_stores, 1,
"the for end bound is stored exactly once: {out}"
);
}
// `compile_arm64` is given an empty source string here, and only the first
// reported error's message is checked.
#[test]
fn a_for_over_a_non_range_iterable_is_rejected_cleanly() {
let typed = typecheck("fn f() { let arr = [1, 2, 3]\nfor x in arr { } }");
let err =
super::super::compile_arm64(&typed, "").expect_err("a non-range for must be rejected");
match &err[0] {
QalaError::Type { message, .. } => {
assert!(
message.contains("non-range") || message.contains("array"),
"message: {message}"
);
}
other => panic!("expected QalaError::Type, got {other:?}"),
}
}
// A hand-built typed `for` whose range has `end: None`. This reaches the
// backend's missing-bound check directly, without going through the parser.
#[test]
fn a_for_range_with_a_missing_bound_is_rejected_cleanly() {
use crate::types::QalaType;
let stmt = TypedStmt::For {
var: "i".to_string(),
var_ty: QalaType::I64,
iter: TypedExpr::Range {
start: Some(Box::new(TypedExpr::Int {
value: 0,
ty: QalaType::I64,
span: Span::new(0, 1),
})),
end: None,
inclusive: false,
ty: QalaType::Array(Box::new(QalaType::I64), None),
span: Span::new(0, 4),
},
body: TypedBlock {
stmts: vec![],
value: None,
ty: QalaType::Void,
span: Span::new(5, 6),
},
span: Span::new(0, 7),
};
let mut backend = Arm64Backend::new("");
let err = backend
.compile_stmt(&stmt)
.expect_err("a for with a missing bound must be rejected");
match err {
QalaError::Type { message, .. } => {
assert!(message.contains("bound"), "message: {message}");
}
other => panic!("expected QalaError::Type, got {other:?}"),
}
}
}