use crate::typed_ast::{
TypedBlock, TypedElseBranch, TypedExpr, TypedFnDecl, TypedInterpPart, TypedStmt,
};
use std::collections::{HashMap, HashSet};
/// Width in bytes of one frame slot (parameter, binding, or scratch).
const SLOT_SIZE: i64 = 8;
/// Bytes at the bottom of every frame that hold the saved `fp` and `lr`.
const FP_LR_PAIR: i64 = 16;

/// Turns the byte size of a function's locals into the (negative) `sp`
/// adjustment used by the prologue and epilogue.
///
/// The saved `fp`/`lr` pair is added on top of `locals`, the sum is rounded
/// up to a multiple of 16, and the result is returned negated:
/// `0 -> -16`, `8 -> -32`, `16 -> -32`, `24 -> -48`.
fn compute_alloc(locals: i64) -> i64 {
    let frame_bytes = FP_LR_PAIR + locals;
    // Negating first and then clearing the low four bits rounds toward
    // negative infinity, i.e. it rounds the magnitude *up* to 16.
    (-frame_bytes) & !0xF
}
/// Stack-frame plan for one function: the `fp` offset of every parameter,
/// binding, and scratch slot, plus how far the prologue moves `sp`.
///
/// All slot offsets are positive byte offsets from `fp`.
#[derive(Debug)]
pub(crate) struct FrameLayout {
    /// Parameter name -> `fp` offset of its slot.
    slots: HashMap<String, i64>,
    /// `fp` offset of each binding slot, indexed in planning order.
    binding_slots: Vec<i64>,
    /// Number of scratch slots the planner reserved.
    // Only read through `scratch_count()`, which is itself marked
    // `allow(dead_code)`; in this file its only callers are the tests.
    #[allow(dead_code)]
    scratch_count: i64,
    /// `fp` offset of the first scratch slot (directly above the named slots).
    scratch_base: i64,
    /// Scratch slots currently claimed and not yet released.
    scratch_in_use: i64,
    /// Negative, 16-byte-aligned `sp` adjustment used by prologue/epilogue.
    alloc: i64,
}
impl FrameLayout {
    /// Plans the stack frame for `decl`.
    ///
    /// Offsets are relative to `fp`, which the prologue sets to the adjusted
    /// `sp`:
    ///
    /// - `[fp, 0]` / `[fp, 8]`: the saved `fp` and `lr` (`FP_LR_PAIR` bytes);
    /// - then one `SLOT_SIZE` slot per parameter, in declaration order;
    /// - then one slot per binding reported by `count_binding_slots`;
    /// - then the scratch slots, as many as `max_spill_depth` reports.
    ///
    /// `compute_alloc` rounds the total up to a multiple of 16. `fn_names`
    /// is the set of user-defined function names; scratch sizing uses it to
    /// tell the builtin `print`/`println` apart from user functions of the
    /// same name.
    pub(crate) fn plan_frame(decl: &TypedFnDecl, fn_names: &HashSet<String>) -> FrameLayout {
        let param_count = decl.params.len() as i64;

        // Parameters take the first named slots, right above the saved pair.
        let slots: HashMap<String, i64> = decl
            .params
            .iter()
            .enumerate()
            .map(|(i, param)| (param.name.clone(), FP_LR_PAIR + SLOT_SIZE * i as i64))
            .collect();

        // Bindings follow the parameters, one slot each.
        let binding_base = FP_LR_PAIR + SLOT_SIZE * param_count;
        let binding_count = count_binding_slots(&decl.body);
        let binding_slots: Vec<i64> = (0..binding_count)
            .map(|k| binding_base + SLOT_SIZE * k)
            .collect();

        // Scratch sits directly above every named (parameter + binding) slot.
        let named_slot_count = param_count + binding_count;
        let scratch_base = FP_LR_PAIR + SLOT_SIZE * named_slot_count;
        let scratch_count = max_spill_depth(&decl.body, fn_names);
        let alloc = compute_alloc(SLOT_SIZE * (named_slot_count + scratch_count));

        FrameLayout {
            slots,
            binding_slots,
            scratch_count,
            scratch_base,
            scratch_in_use: 0,
            alloc,
        }
    }

    /// The `fp` offset of parameter `name`, or `None` if no parameter has
    /// that name.
    pub(crate) fn slot_of(&self, name: &str) -> Option<i64> {
        self.slots.get(name).copied()
    }

    /// The `fp` offset of the `index`-th binding slot, or `None` if the
    /// planner reserved fewer bindings.
    pub(crate) fn binding_slot(&self, index: usize) -> Option<i64> {
        self.binding_slots.get(index).copied()
    }

    /// Claims the next free scratch slot and returns its `fp` offset.
    ///
    /// Claims nest like a stack; pair each claim with `release_scratch`.
    /// NOTE(review): this does not check against the planned
    /// `scratch_count`, so claiming more slots than the planner reserved
    /// yields offsets past the reserved locals.
    pub(crate) fn claim_scratch(&mut self) -> i64 {
        let offset = self.scratch_base + SLOT_SIZE * self.scratch_in_use;
        self.scratch_in_use += 1;
        offset
    }

    /// Releases the most recently claimed scratch slot.
    ///
    /// # Panics
    ///
    /// Panics if no scratch slot is currently claimed.
    pub(crate) fn release_scratch(&mut self) {
        // A hard `assert!`, not `debug_assert!`: in a release build an
        // unbalanced release would drive the counter to -1, and the next
        // claim would return `scratch_base - 8` -- a named slot, or the saved
        // `lr` when there are none -- silently clobbering it.
        assert!(
            self.scratch_in_use > 0,
            "release_scratch with no slot claimed"
        );
        self.scratch_in_use -= 1;
    }

    /// The two prologue instructions. A pre-indexed `stp` both reserves the
    /// frame and saves `fp`/`lr` at its bottom; `mov fp, sp` then anchors
    /// every slot offset.
    ///
    /// NOTE(review): the writeback forms of `stp`/`ldp` (64-bit) encode a
    /// scaled 7-bit immediate (-512..=504), so a frame with `alloc < -512`
    /// cannot be emitted this way -- confirm callers never plan frames that
    /// large.
    pub(crate) fn prologue(&self) -> [String; 2] {
        [
            format!("stp fp, lr, [sp, {}]!", self.alloc),
            "mov fp, sp".to_string(),
        ]
    }

    /// The two epilogue instructions. A post-indexed `ldp` restores `fp`/`lr`
    /// and releases the whole frame; then `ret`.
    pub(crate) fn epilogue(&self) -> [String; 2] {
        [
            format!("ldp fp, lr, [sp], {}", -self.alloc),
            "ret".to_string(),
        ]
    }

    /// Number of scratch slots the planner reserved.
    #[allow(dead_code)]
    pub(crate) fn scratch_count(&self) -> i64 {
        self.scratch_count
    }
}
fn count_binding_slots(block: &TypedBlock) -> i64 {
let mut count = 0;
for stmt in &block.stmts {
count += stmt_binding_slots(stmt);
}
count
}
fn stmt_binding_slots(stmt: &TypedStmt) -> i64 {
match stmt {
TypedStmt::Let { .. } => 1,
TypedStmt::If {
then_block,
else_branch,
..
} => {
let mut c = count_binding_slots(then_block);
if let Some(branch) = else_branch {
c += else_branch_binding_slots(branch);
}
c
}
TypedStmt::While { body, .. } => count_binding_slots(body),
TypedStmt::For { body, .. } => 2 + count_binding_slots(body),
TypedStmt::Return { .. }
| TypedStmt::Break { .. }
| TypedStmt::Continue { .. }
| TypedStmt::Defer { .. }
| TypedStmt::Expr { .. } => 0,
}
}
fn else_branch_binding_slots(branch: &TypedElseBranch) -> i64 {
match branch {
TypedElseBranch::Block(b) => count_binding_slots(b),
TypedElseBranch::If(stmt) => stmt_binding_slots(stmt),
}
}
/// The largest scratch requirement of any statement in `block` or of its
/// tail value; 0 for an empty block.
///
/// Sibling needs are combined with `max` rather than summed. This assumes
/// each statement has released its scratch before the next one starts.
fn max_spill_depth(block: &TypedBlock, fn_names: &HashSet<String>) -> i64 {
    let stmt_depths = block
        .stmts
        .iter()
        .map(|stmt| stmt_spill_depth(stmt, fn_names));
    let tail_depth = block
        .value
        .iter()
        .map(|value| expr_spill_depth(value, fn_names));
    stmt_depths.chain(tail_depth).fold(0, i64::max)
}
/// Scratch slots needed while lowering one statement.
///
/// Each sub-part (condition, body, branches) is evaluated on its own, so the
/// statement needs the maximum of its parts, not the sum.
fn stmt_spill_depth(stmt: &TypedStmt, fn_names: &HashSet<String>) -> i64 {
    match stmt {
        TypedStmt::Break { .. } | TypedStmt::Continue { .. } => 0,
        TypedStmt::Let { init, .. } => expr_spill_depth(init, fn_names),
        TypedStmt::Expr { expr, .. } => expr_spill_depth(expr, fn_names),
        TypedStmt::Defer { expr, .. } => expr_spill_depth(expr, fn_names),
        TypedStmt::Return { value, .. } => match value {
            Some(returned) => expr_spill_depth(returned, fn_names),
            None => 0,
        },
        TypedStmt::While { cond, body, .. } => {
            let cond_depth = expr_spill_depth(cond, fn_names);
            let body_depth = max_spill_depth(body, fn_names);
            cond_depth.max(body_depth)
        }
        TypedStmt::For { iter, body, .. } => {
            let iter_depth = expr_spill_depth(iter, fn_names);
            let body_depth = max_spill_depth(body, fn_names);
            iter_depth.max(body_depth)
        }
        TypedStmt::If {
            cond,
            then_block,
            else_branch,
            ..
        } => {
            let cond_depth = expr_spill_depth(cond, fn_names);
            let then_depth = max_spill_depth(then_block, fn_names);
            let else_depth = match else_branch {
                Some(branch) => else_branch_spill_depth(branch, fn_names),
                None => 0,
            };
            cond_depth.max(then_depth).max(else_depth)
        }
    }
}
/// Scratch needed by the `else` side of an `if`: either a plain `else`
/// block, or an `else if`, which is sized as a nested `if` statement.
fn else_branch_spill_depth(branch: &TypedElseBranch, fn_names: &HashSet<String>) -> i64 {
    match branch {
        TypedElseBranch::If(nested_if) => stmt_spill_depth(nested_if, fn_names),
        TypedElseBranch::Block(else_block) => max_spill_depth(else_block, fn_names),
    }
}
/// Number of scratch slots needed while evaluating `expr`.
///
/// - Binary operators other than `&&`/`||`: 1 plus the deeper operand.
/// - `&&` / `||`: just the deeper operand (no slot of their own).
/// - Builtin `print`/`println` calls: sized by `printf_call_spill_depth`.
/// - Other calls: one slot per argument, plus the deepest of the callee and
///   the arguments.
/// - Unary, parenthesised, block, and range expressions: their deepest child.
/// - Anything else (literals, identifiers, other variants): 0.
fn expr_spill_depth(expr: &TypedExpr, fn_names: &HashSet<String>) -> i64 {
    use crate::ast::BinOp;
    match expr {
        TypedExpr::Binary { op, lhs, rhs, .. } => {
            let deeper_operand =
                expr_spill_depth(lhs, fn_names).max(expr_spill_depth(rhs, fn_names));
            let own_slot = if matches!(op, BinOp::And | BinOp::Or) {
                0
            } else {
                1
            };
            own_slot + deeper_operand
        }
        TypedExpr::Call { callee, args, .. } => {
            if is_print_callee(callee, fn_names) {
                printf_call_spill_depth(args, fn_names)
            } else {
                let arity = args.len() as i64;
                let deepest_child = args
                    .iter()
                    .map(|arg| expr_spill_depth(arg, fn_names))
                    .fold(expr_spill_depth(callee, fn_names), i64::max);
                arity + deepest_child
            }
        }
        TypedExpr::Unary { operand, .. } => expr_spill_depth(operand, fn_names),
        TypedExpr::Paren { inner, .. } => expr_spill_depth(inner, fn_names),
        TypedExpr::Block { block, .. } => max_spill_depth(block, fn_names),
        TypedExpr::Range { start, end, .. } => {
            let start_depth = match start.as_deref() {
                Some(bound) => expr_spill_depth(bound, fn_names),
                None => 0,
            };
            let end_depth = match end.as_deref() {
                Some(bound) => expr_spill_depth(bound, fn_names),
                None => 0,
            };
            start_depth.max(end_depth)
        }
        _ => 0,
    }
}
/// True when `callee` names the builtin `print` or `println`. A user-defined
/// function with the same name shadows the builtin, so names present in
/// `fn_names` never count.
fn is_print_callee(callee: &TypedExpr, fn_names: &HashSet<String>) -> bool {
    let TypedExpr::Ident { name, .. } = callee else {
        return false;
    };
    let is_builtin_name = matches!(name.as_str(), "print" | "println");
    is_builtin_name && !fn_names.contains(name)
}
/// Scratch needed by a call to the builtin `print`/`println`.
///
/// If the first argument is an interpolated string, the need is one slot per
/// `{...}` hole plus the deepest hole expression. Any other argument is
/// still an expression that must be evaluated; in `print(greet(a, b))`, for
/// example, the nested call claims scratch of its own. So in that case the
/// need is the deepest argument's. Reporting 0 there would under-size the
/// frame, and since `FrameLayout::claim_scratch` does not check claims
/// against the plan, the extra slots would land past the reserved locals.
fn printf_call_spill_depth(args: &[TypedExpr], fn_names: &HashSet<String>) -> i64 {
    let Some(TypedExpr::Interpolation { parts, .. }) = args.first() else {
        return args
            .iter()
            .map(|arg| expr_spill_depth(arg, fn_names))
            .max()
            .unwrap_or(0);
    };
    let mut hole_count = 0;
    let mut deepest = 0;
    for part in parts {
        if let TypedInterpPart::Expr(expr) = part {
            hole_count += 1;
            deepest = deepest.max(expr_spill_depth(expr, fn_names));
        }
    }
    hole_count + deepest
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::BinOp;
    use crate::effects::EffectSet;
    use crate::span::Span;
    use crate::types::QalaType;

    // ---- construction helpers ------------------------------------------

    /// A one-byte span at offset `n`.
    fn span(n: usize) -> Span {
        Span::new(n, n + 1)
    }

    /// An `i64` integer literal.
    fn int(value: i64) -> TypedExpr {
        TypedExpr::Int {
            value,
            ty: QalaType::I64,
            span: span(0),
        }
    }

    /// A reference to the `i64` variable `name`.
    fn ident(name: &str) -> TypedExpr {
        TypedExpr::Ident {
            name: name.to_string(),
            ty: QalaType::I64,
            span: span(0),
        }
    }

    /// `lhs + rhs`, typed as `i64`.
    fn add(lhs: TypedExpr, rhs: TypedExpr) -> TypedExpr {
        TypedExpr::Binary {
            op: BinOp::Add,
            lhs: Box::new(lhs),
            rhs: Box::new(rhs),
            ty: QalaType::I64,
            span: span(0),
        }
    }

    /// An `i64` block with no statements whose tail value is `value`.
    fn value_block(value: TypedExpr) -> TypedBlock {
        TypedBlock {
            stmts: vec![],
            value: Some(Box::new(value)),
            ty: QalaType::I64,
            span: span(0),
        }
    }

    /// An empty `void` block.
    fn empty_body() -> TypedBlock {
        TypedBlock {
            stmts: vec![],
            value: None,
            ty: QalaType::Void,
            span: span(0),
        }
    }

    /// A pure function `f` returning `i64`, with `param_count` `i64`
    /// parameters named `p0`, `p1`, ...
    fn fn_decl(param_count: usize, body: TypedBlock) -> TypedFnDecl {
        let params = (0..param_count)
            .map(|i| crate::typed_ast::TypedParam {
                is_self: false,
                name: format!("p{i}"),
                ty: QalaType::I64,
                default: None,
                span: span(i),
            })
            .collect();
        TypedFnDecl {
            type_name: None,
            name: "f".to_string(),
            params,
            ret_ty: QalaType::I64,
            effect: EffectSet::pure(),
            body,
            span: span(0),
        }
    }

    /// No user-defined functions.
    fn no_fns() -> HashSet<String> {
        HashSet::new()
    }

    /// The names of every function declared in `typed`.
    fn fn_names_of(typed: &crate::typed_ast::TypedAst) -> HashSet<String> {
        typed
            .iter()
            .filter_map(|item| match item {
                crate::typed_ast::TypedItem::Fn(d) => Some(d.name.clone()),
                _ => None,
            })
            .collect()
    }

    /// Lexes, parses, and typechecks `src`, returning its first function.
    fn fn_from_source(src: &str) -> TypedFnDecl {
        let tokens = crate::lexer::Lexer::tokenize(src).expect("lex failed");
        let ast = crate::parser::Parser::parse(&tokens).expect("parse failed");
        let (typed, terrors, _) = crate::typechecker::check_program(&ast, src);
        assert!(terrors.is_empty(), "typecheck errors: {terrors:?}");
        typed
            .into_iter()
            .find_map(|item| match item {
                crate::typed_ast::TypedItem::Fn(d) => Some(d),
                _ => None,
            })
            .expect("no function in source")
    }

    /// Lexes, parses, and typechecks `src`, then plans the frame of the
    /// function called `name`, using every declared function as `fn_names`.
    fn layout_of_fn_named(src: &str, name: &str) -> FrameLayout {
        let tokens = crate::lexer::Lexer::tokenize(src).expect("lex failed");
        let ast = crate::parser::Parser::parse(&tokens).expect("parse failed");
        let (typed, terrors, _) = crate::typechecker::check_program(&ast, src);
        assert!(terrors.is_empty(), "typecheck errors: {terrors:?}");
        let decl = typed
            .iter()
            .find_map(|item| match item {
                crate::typed_ast::TypedItem::Fn(d) if d.name == name => Some(d.clone()),
                _ => None,
            })
            .unwrap_or_else(|| panic!("no function named `{name}`"));
        FrameLayout::plan_frame(&decl, &fn_names_of(&typed))
    }

    // ---- alloc formula -------------------------------------------------

    #[test]
    fn alloc_formula_matches_the_hand_computed_values() {
        assert_eq!(compute_alloc(0), -16, "locals=0 must round to -16");
        assert_eq!(
            compute_alloc(8),
            -32,
            "locals=8 (24 bytes) must round up to -32"
        );
        assert_eq!(
            compute_alloc(16),
            -32,
            "locals=16 (32 bytes) must round to -32"
        );
        assert_eq!(
            compute_alloc(24),
            -48,
            "locals=24 (40 bytes) must round up to -48"
        );
    }

    #[test]
    fn alloc_is_always_a_negative_multiple_of_sixteen() {
        for locals in (0..=512).step_by(8) {
            let a = compute_alloc(locals);
            assert!(a <= -16, "alloc must reserve at least the fp/lr pair");
            assert_eq!(
                a % 16,
                0,
                "alloc must be a multiple of 16 for locals={locals}"
            );
        }
    }

    // ---- hand-built declarations ---------------------------------------

    #[test]
    fn a_leaf_function_with_no_params_has_a_minus_sixteen_frame() {
        let layout = FrameLayout::plan_frame(&fn_decl(0, empty_body()), &no_fns());
        assert_eq!(layout.alloc, -16);
        assert_eq!(layout.prologue()[0], "stp fp, lr, [sp, -16]!");
        assert_eq!(layout.prologue()[1], "mov fp, sp");
        assert_eq!(layout.epilogue()[0], "ldp fp, lr, [sp], 16");
        assert_eq!(layout.epilogue()[1], "ret");
    }

    #[test]
    fn three_params_no_scratch_gives_a_minus_forty_eight_frame() {
        let body = value_block(ident("p0"));
        let layout = FrameLayout::plan_frame(&fn_decl(3, body), &no_fns());
        assert_eq!(layout.scratch_count(), 0, "a bare ident needs no scratch");
        assert_eq!(layout.alloc, -48, "3 params (24 bytes) -> alloc -48");
    }

    #[test]
    fn parameter_slots_start_at_sixteen_above_the_saved_pair() {
        let layout = FrameLayout::plan_frame(&fn_decl(3, empty_body()), &no_fns());
        assert_eq!(layout.slot_of("p0"), Some(16));
        assert_eq!(layout.slot_of("p1"), Some(24));
        assert_eq!(layout.slot_of("p2"), Some(32));
        assert_eq!(layout.slot_of("missing"), None);
    }

    #[test]
    fn max_spill_depth_is_zero_for_a_leaf_expression() {
        let body = value_block(int(7));
        assert_eq!(max_spill_depth(&body, &no_fns()), 0);
    }

    #[test]
    fn max_spill_depth_counts_a_single_binary_op_as_one() {
        let body = value_block(add(int(1), int(2)));
        assert_eq!(max_spill_depth(&body, &no_fns()), 1);
    }

    #[test]
    fn max_spill_depth_grows_with_right_leaning_nesting() {
        // 1 + (2 + (3 + 4))
        let body = value_block(add(int(1), add(int(2), add(int(3), int(4)))));
        assert_eq!(max_spill_depth(&body, &no_fns()), 3);
    }

    #[test]
    fn left_leaning_nesting_accumulates_spills() {
        // ((1 + 2) + 3) + 4
        let body = value_block(add(add(add(int(1), int(2)), int(3)), int(4)));
        assert_eq!(max_spill_depth(&body, &no_fns()), 3);
    }

    #[test]
    fn claim_and_release_scratch_balance() {
        let mut layout = FrameLayout::plan_frame(&fn_decl(2, empty_body()), &no_fns());
        let first = layout.claim_scratch();
        let second = layout.claim_scratch();
        assert_eq!(first, 32);
        assert_eq!(second, 40);
        assert_ne!(first, second, "nested claims must be distinct slots");
        layout.release_scratch();
        layout.release_scratch();
        assert_eq!(layout.claim_scratch(), 32);
    }

    #[test]
    fn scratch_slots_are_counted_into_the_frame() {
        let body = value_block(add(int(1), int(2)));
        let layout = FrameLayout::plan_frame(&fn_decl(1, body), &no_fns());
        assert_eq!(layout.scratch_count(), 1);
        assert_eq!(layout.alloc, -32);
    }

    // ---- declarations from source --------------------------------------

    #[test]
    fn the_arithmetic_snapshot_function_has_a_minus_sixty_four_frame() {
        let decl = fn_from_source("fn arith(a: i64, b: i64) -> i64 { (a + b) * (a - b) / 2 % 7 }");
        let layout = FrameLayout::plan_frame(&decl, &no_fns());
        assert_eq!(
            layout.scratch_count(),
            4,
            "the left spine is 4 operators deep"
        );
        assert_eq!(
            layout.alloc, -64,
            "2 params + 4 scratch = 48 bytes -> alloc -64"
        );
        assert_eq!(layout.prologue()[0], "stp fp, lr, [sp, -64]!");
        assert_eq!(layout.epilogue()[0], "ldp fp, lr, [sp], 64");
    }

    #[test]
    fn the_comparison_snapshot_function_has_a_minus_forty_eight_frame() {
        let decl = fn_from_source("fn cmp(a: i64, b: i64) -> bool { a == b }");
        let layout = FrameLayout::plan_frame(&decl, &no_fns());
        assert_eq!(
            layout.scratch_count(),
            1,
            "one comparison -> one scratch slot"
        );
        assert_eq!(
            layout.alloc, -48,
            "2 params + 1 scratch = 24 bytes -> alloc -48"
        );
    }

    #[test]
    fn the_boolean_snapshot_function_has_a_minus_thirty_two_frame() {
        let decl = fn_from_source("fn andor(a: bool, b: bool) -> bool { a && b || !a }");
        let layout = FrameLayout::plan_frame(&decl, &no_fns());
        assert_eq!(
            layout.scratch_count(),
            0,
            "short-circuit ops claim no scratch"
        );
        assert_eq!(
            layout.alloc, -32,
            "2 params + 0 scratch = 16 bytes -> alloc -32"
        );
    }

    #[test]
    fn a_function_with_two_lets_assigns_them_concrete_distinct_slots() {
        let decl = fn_from_source("fn f() { let x = 1\nlet y = 2 }");
        let layout = FrameLayout::plan_frame(&decl, &no_fns());
        assert_eq!(layout.binding_slot(0), Some(16), "first let -> [fp, 16]");
        assert_eq!(layout.binding_slot(1), Some(24), "second let -> [fp, 24]");
        assert_eq!(layout.binding_slot(2), None, "only two bindings exist");
        assert_eq!(layout.alloc, -32, "2 lets (16 bytes) -> alloc -32");
    }

    #[test]
    fn let_bindings_are_counted_after_the_parameters() {
        let decl = fn_from_source("fn f(a: i64) { let x = 1 }");
        let layout = FrameLayout::plan_frame(&decl, &no_fns());
        assert_eq!(layout.slot_of("a"), Some(16), "param a -> [fp, 16]");
        assert_eq!(layout.binding_slot(0), Some(24), "the let -> [fp, 24]");
        assert_eq!(layout.alloc, -32, "1 param + 1 let = 16 bytes -> alloc -32");
    }

    #[test]
    fn a_shadowing_let_gets_its_own_distinct_slot() {
        let decl = fn_from_source("fn f() -> i64 { let x = 1\nlet x = 2\nx }");
        let layout = FrameLayout::plan_frame(&decl, &no_fns());
        assert_eq!(layout.binding_slot(0), Some(16), "outer x -> [fp, 16]");
        assert_eq!(layout.binding_slot(1), Some(24), "shadowing x -> [fp, 24]");
        assert_ne!(
            layout.binding_slot(0),
            layout.binding_slot(1),
            "a shadowing let must not reuse the outer slot"
        );
    }

    #[test]
    fn lets_nested_inside_an_if_are_counted_into_the_frame() {
        let decl =
            fn_from_source("fn f(c: bool) { let a = 1\nif c { let b = 2 } else { let d = 3 } }");
        let layout = FrameLayout::plan_frame(&decl, &no_fns());
        assert_eq!(
            layout.binding_slot(0),
            Some(24),
            "outer let a, past param c"
        );
        assert_eq!(layout.binding_slot(1), Some(32), "let b in the then-block");
        assert_eq!(layout.binding_slot(2), Some(40), "let d in the else-block");
        assert_eq!(layout.alloc, -48, "1 param + 3 nested lets -> alloc -48");
    }

    #[test]
    fn a_for_loop_reserves_two_slots_a_variable_and_an_end_bound() {
        let decl = fn_from_source("fn f() { for i in 0..10 { } }");
        let layout = FrameLayout::plan_frame(&decl, &no_fns());
        assert_eq!(layout.binding_slot(0), Some(16), "for-variable slot");
        assert_eq!(layout.binding_slot(1), Some(24), "for-end-bound slot");
        assert_eq!(
            layout.binding_slot(2),
            None,
            "a bare for is exactly two slots"
        );
        assert_eq!(layout.alloc, -32, "one for loop (16 bytes) -> alloc -32");
    }

    #[test]
    fn a_let_inside_a_for_body_is_counted_after_the_for_slots() {
        let decl = fn_from_source("fn f() { for i in 0..3 { let s = i } }");
        let layout = FrameLayout::plan_frame(&decl, &no_fns());
        assert_eq!(layout.binding_slot(0), Some(16), "for-variable i");
        assert_eq!(layout.binding_slot(1), Some(24), "for-end bound");
        assert_eq!(layout.binding_slot(2), Some(32), "the let inside the body");
        assert_eq!(layout.alloc, -48, "for (2 slots) + nested let -> alloc -48");
    }

    #[test]
    fn a_function_with_no_calls_reserves_no_scratch() {
        let decl = fn_from_source("fn f(a: i64) -> i64 { a }");
        let layout = FrameLayout::plan_frame(&decl, &no_fns());
        assert_eq!(layout.scratch_count(), 0, "a bare ident needs no scratch");
        assert_eq!(layout.alloc, -32, "1 param, no call -> alloc -32");
    }

    #[test]
    fn a_two_param_one_let_function_with_a_three_arg_call_has_concrete_slots() {
        let src = "fn add3(a: i64, b: i64, c: i64) -> i64 { a }\n\
                   fn caller(x: i64, y: i64) -> i64 { let r = add3(x, y, 1)\nr }";
        let layout = layout_of_fn_named(src, "caller");
        assert_eq!(layout.slot_of("x"), Some(16), "param x -> [fp, 16]");
        assert_eq!(layout.slot_of("y"), Some(24), "param y -> [fp, 24]");
        assert_eq!(layout.binding_slot(0), Some(32), "let r -> [fp, 32]");
        assert_eq!(
            layout.scratch_count(),
            3,
            "a 3-arg call -> three scratch slots"
        );
        assert_eq!(
            layout.alloc, -64,
            "2 params + 1 let + 3-arg call -> alloc -64"
        );
        assert_eq!(layout.prologue()[0], "stp fp, lr, [sp, -64]!");
        assert_eq!(layout.epilogue()[0], "ldp fp, lr, [sp], 64");
    }

    #[test]
    fn sibling_calls_in_separate_statements_size_to_the_widest_not_the_sum() {
        let src = "fn one(a: i64) -> i64 { a }\n\
                   fn two(a: i64, b: i64) -> i64 { a }\n\
                   fn f() { let p = one(1)\nlet q = two(2, 3) }";
        let layout = layout_of_fn_named(src, "f");
        assert_eq!(layout.scratch_count(), 2, "the widest call passes two args");
        assert_eq!(layout.alloc, -48, "two lets + a 2-arg call -> alloc -48");
    }

    #[test]
    fn a_nested_call_argument_adds_the_inner_calls_depth_to_the_outer_arity() {
        let src = "fn inner(a: i64, b: i64, c: i64) -> i64 { a }\n\
                   fn outer(x: i64) -> i64 { x }\n\
                   fn f() { let r = outer(inner(1, 2, 3)) }";
        let layout = layout_of_fn_named(src, "f");
        assert_eq!(
            layout.scratch_count(),
            4,
            "outer arity (1) + inner call depth (3) = 4 scratch slots"
        );
        assert_eq!(
            layout.alloc, -64,
            "1 let + a nested 3-arg call -> alloc -64"
        );
    }
}