use super::{CodeGen, CodeGenError, mangle_name};
use crate::ast::{Statement, WordDef};
use crate::types::{StackType, Type};
use std::fmt::Write as _;
/// Register classes a value can occupy in specialized (register-based) code.
///
/// Each variant corresponds to one LLVM first-class type; see `llvm_type`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RegisterType {
/// Rendered as LLVM `i64`. `from_type` maps both `Type::Int` and `Type::Bool` here.
I64,
/// Rendered as LLVM `double`. `from_type` maps `Type::Float` here.
Double,
}
impl RegisterType {
    /// Classifies a language-level type into a register class.
    ///
    /// `Int` and `Bool` share the integer register class, `Float` uses the
    /// floating-point class, and every other type yields `None`.
    pub fn from_type(ty: &Type) -> Option<Self> {
        if matches!(ty, Type::Int | Type::Bool) {
            Some(Self::I64)
        } else if matches!(ty, Type::Float) {
            Some(Self::Double)
        } else {
            None
        }
    }

    /// Returns the LLVM IR spelling of this register class.
    pub fn llvm_type(&self) -> &'static str {
        match *self {
            Self::I64 => "i64",
            Self::Double => "double",
        }
    }
}
/// Register-level signature of a specialized word.
///
/// `can_specialize` builds it from a word's declared stack effect, and it
/// determines both the mangled function suffix and the LLVM return type.
#[derive(Debug, Clone)]
pub struct SpecSignature {
/// Register classes of the parameters, in the order they are passed.
pub inputs: Vec<RegisterType>,
/// Register classes of the results; more than one is returned as an LLVM struct.
pub outputs: Vec<RegisterType>,
}
impl SpecSignature {
pub fn suffix(&self) -> String {
if self.inputs.len() == 1 && self.outputs.len() == 1 {
match (self.inputs[0], self.outputs[0]) {
(RegisterType::I64, RegisterType::I64) => "_i64".to_string(),
(RegisterType::Double, RegisterType::Double) => "_f64".to_string(),
(RegisterType::I64, RegisterType::Double) => "_i64_to_f64".to_string(),
(RegisterType::Double, RegisterType::I64) => "_f64_to_i64".to_string(),
}
} else {
let mut suffix = String::new();
for ty in &self.inputs {
suffix.push('_');
suffix.push_str(match ty {
RegisterType::I64 => "i",
RegisterType::Double => "f",
});
}
suffix.push_str("_to");
for ty in &self.outputs {
suffix.push('_');
suffix.push_str(match ty {
RegisterType::I64 => "i",
RegisterType::Double => "f",
});
}
suffix
}
}
pub fn is_direct_call(&self) -> bool {
self.outputs.len() == 1
}
pub fn llvm_return_type(&self) -> String {
if self.outputs.len() == 1 {
self.outputs[0].llvm_type().to_string()
} else {
let types: Vec<_> = self.outputs.iter().map(|t| t.llvm_type()).collect();
format!("{{ {} }}", types.join(", "))
}
}
}
/// Compile-time model of the value stack used while emitting specialized code.
///
/// Stack shuffles only rearrange this list; they emit no instructions.
#[derive(Debug, Clone)]
pub struct RegisterContext {
/// `(SSA name without the leading '%', register class)` pairs; the last element is the top of stack.
pub values: Vec<(String, RegisterType)>,
}
impl RegisterContext {
    /// Creates a context with no tracked values.
    pub fn new() -> Self {
        Self { values: Vec::new() }
    }

    /// Creates a context pre-populated with the given values (last = top).
    pub fn from_params(params: &[(String, RegisterType)]) -> Self {
        Self {
            values: params.iter().cloned().collect(),
        }
    }

    /// Pushes an SSA value onto the top of the virtual stack.
    pub fn push(&mut self, ssa_var: String, ty: RegisterType) {
        self.values.push((ssa_var, ty));
    }

    /// Removes and returns the top value, or `None` when the stack is empty.
    pub fn pop(&mut self) -> Option<(String, RegisterType)> {
        self.values.pop()
    }

    /// Number of values currently tracked.
    #[cfg_attr(not(test), allow(dead_code))]
    pub fn len(&self) -> usize {
        self.values.len()
    }

    /// `( a -- a a )`; does nothing on an empty stack.
    pub fn dup(&mut self) {
        let Some(top) = self.values.last().cloned() else {
            return;
        };
        self.values.push(top);
    }

    /// `( a -- )`; does nothing on an empty stack.
    pub fn drop(&mut self) {
        let shortened = self.values.len().saturating_sub(1);
        self.values.truncate(shortened);
    }

    /// `( a b -- b a )`; does nothing with fewer than two values.
    pub fn swap(&mut self) {
        if let [.., below, top] = self.values.as_mut_slice() {
            std::mem::swap(below, top);
        }
    }

    /// `( a b -- a b a )`; does nothing with fewer than two values.
    pub fn over(&mut self) {
        if let [.., below, _] = self.values.as_slice() {
            let copy = below.clone();
            self.values.push(copy);
        }
    }

    /// `( a b c -- b c a )`; does nothing with fewer than three values.
    pub fn rot(&mut self) {
        if let Some(start) = self.values.len().checked_sub(3) {
            // Rotating the top three left by one moves the deepest of them to the top.
            self.values[start..].rotate_left(1);
        }
    }
}
impl Default for RegisterContext {
fn default() -> Self {
Self::new()
}
}
/// Built-in word names that specialized code generation can emit inline.
///
/// `is_statement_specializable` consults this table; each entry has a
/// matching arm in `codegen_specialized_word_call`.
const SPECIALIZABLE_OPS: &[&str] = &[
// Integer arithmetic (symbolic and spelled-out aliases).
"i.+",
"i.add",
"i.-",
"i.subtract",
"i.*",
"i.multiply",
"i./",
"i.divide",
"i.%",
"i.mod",
// Bitwise operations and bit counting.
"band",
"bor",
"bxor",
"bnot",
"shl",
"shr",
"popcount",
"clz",
"ctz",
// Numeric conversions.
"int->float",
"float->int",
// Boolean logic.
"and",
"or",
"not",
// Integer comparisons.
"i.<",
"i.lt",
"i.>",
"i.gt",
"i.<=",
"i.lte",
"i.>=",
"i.gte",
"i.=",
"i.eq",
"i.<>",
"i.neq",
// Float arithmetic.
"f.+",
"f.add",
"f.-",
"f.subtract",
"f.*",
"f.multiply",
"f./",
"f.divide",
// Float comparisons.
"f.<",
"f.lt",
"f.>",
"f.gt",
"f.<=",
"f.lte",
"f.>=",
"f.gte",
"f.=",
"f.eq",
"f.<>",
"f.neq",
// Stack shuffles; `pick` and `roll` additionally require a preceding integer literal.
"dup",
"drop",
"swap",
"over",
"rot",
"nip",
"tuck",
"pick",
"roll",
];
impl CodeGen {
/// Decides whether `word` can be compiled to a register-based function.
///
/// Returns the register signature when the word declares a pure stack
/// effect, every input and output maps to a register class, at least one
/// output exists, and the whole body consists of specializable statements.
/// Returns `None` otherwise.
pub fn can_specialize(&self, word: &WordDef) -> Option<SpecSignature> {
    let effect = word.effect.as_ref().filter(|e| e.is_pure())?;

    let inputs = Self::extract_register_types(&effect.inputs)?;
    let outputs = Self::extract_register_types(&effect.outputs)?;

    // A word without outputs has nothing to return in a register; this
    // also covers the "no inputs and no outputs" case.
    if outputs.is_empty() {
        return None;
    }

    self.is_body_specializable(&word.body, &word.name)
        .then(|| SpecSignature { inputs, outputs })
}
/// Flattens a stack type into register classes, ordered bottom-to-top.
///
/// Walks the `Cons` chain from the top of the stack downwards and stops at
/// `Empty` or at a row variable. Returns `None` as soon as an element has
/// no register representation.
fn extract_register_types(stack: &StackType) -> Option<Vec<RegisterType>> {
    let mut collected = Vec::new();
    let mut cursor = stack;
    while let StackType::Cons { rest, top } = cursor {
        collected.push(RegisterType::from_type(top)?);
        cursor = rest;
    }
    // Elements were gathered top-first; callers expect bottom-first.
    collected.reverse();
    Some(collected)
}
/// Returns true when every statement in `body` can be specialized.
///
/// Tracks whether the previous statement was an integer literal, because
/// `pick` and `roll` are only accepted directly after one.
fn is_body_specializable(&self, body: &[Statement], word_name: &str) -> bool {
    let mut after_int_literal = false;
    body.iter().all(|stmt| {
        let accepted = self.is_statement_specializable(stmt, word_name, after_int_literal);
        after_int_literal = matches!(stmt, Statement::IntLiteral(_));
        accepted
    })
}
/// Returns true when a single statement can be emitted in specialized mode.
///
/// Numeric and boolean literals are accepted. Strings, symbols, quotations
/// and `match` are rejected. A word call is accepted when it is a recursive
/// call, a built-in from `SPECIALIZABLE_OPS`, or an already specialized
/// word; `pick`/`roll` additionally need a preceding integer literal. An
/// `if` is accepted when both of its branches are.
fn is_statement_specializable(
    &self,
    stmt: &Statement,
    word_name: &str,
    prev_was_int_literal: bool,
) -> bool {
    match stmt {
        Statement::IntLiteral(_) | Statement::FloatLiteral(_) | Statement::BoolLiteral(_) => {
            true
        }
        Statement::StringLiteral(_)
        | Statement::Symbol(_)
        | Statement::Quotation { .. }
        | Statement::Match { .. } => false,
        Statement::WordCall { name, .. } => {
            if name == word_name {
                // Recursion into the word currently being specialized.
                true
            } else if matches!(name.as_str(), "pick" | "roll") && !prev_was_int_literal {
                // The depth must be a compile-time constant.
                false
            } else {
                SPECIALIZABLE_OPS.contains(&name.as_str())
                    || self.specialized_words.contains_key(name)
            }
        }
        Statement::If {
            then_branch,
            else_branch,
            ..
        } => {
            self.is_body_specializable(then_branch, word_name)
                && else_branch
                    .as_ref()
                    .is_none_or(|stmts| self.is_body_specializable(stmts, word_name))
        }
    }
}
/// Emits the register-based LLVM function for `word` and records it in
/// `specialized_words`.
///
/// The function is named `seq_<mangled name><signature suffix>` and takes
/// one parameter `%argN` per input. The body statements are emitted in
/// order, and the last one is told it sits in return position.
///
/// An empty body is a valid specializable word (for example an identity
/// word), so in that case the parameters are returned directly. Without
/// this the emitted function would have no terminator.
///
/// # Errors
/// Propagates any `CodeGenError` raised while writing output or emitting
/// a statement.
pub fn codegen_specialized_word(
    &mut self,
    word: &WordDef,
    sig: &SpecSignature,
) -> Result<(), CodeGenError> {
    let spec_name = format!("seq_{}{}", mangle_name(&word.name), sig.suffix());
    let return_type = sig.llvm_return_type();
    let params: Vec<String> = sig
        .inputs
        .iter()
        .enumerate()
        .map(|(i, ty)| format!("{} %arg{}", ty.llvm_type(), i))
        .collect();

    writeln!(
        &mut self.output,
        "define {} @{}({}) {{",
        return_type,
        spec_name,
        params.join(", ")
    )?;
    writeln!(&mut self.output, "entry:")?;

    // The virtual stack starts out holding the parameters, bottom first.
    let initial_params: Vec<(String, RegisterType)> = sig
        .inputs
        .iter()
        .enumerate()
        .map(|(i, ty)| (format!("arg{}", i), *ty))
        .collect();
    let mut ctx = RegisterContext::from_params(&initial_params);

    if word.body.is_empty() {
        // No statement exists to emit the terminator, so return the inputs as they are.
        self.emit_specialized_return(&ctx, sig)?;
    } else {
        let last_index = word.body.len() - 1;
        let mut prev_int_literal: Option<i64> = None;
        for (i, stmt) in word.body.iter().enumerate() {
            self.codegen_specialized_statement(
                &mut ctx,
                stmt,
                &word.name,
                sig,
                i == last_index,
                &mut prev_int_literal,
            )?;
        }
    }

    writeln!(&mut self.output, "}}")?;
    writeln!(&mut self.output)?;

    self.specialized_words
        .insert(word.name.clone(), sig.clone());
    Ok(())
}
/// Emits one statement of a specialized word.
///
/// `prev_int_literal` carries the value of an integer literal to the
/// statement that follows it (used by `pick`/`roll`). It is cleared on
/// entry and set again only when this statement is itself an integer
/// literal.
///
/// When `is_last` is set, a return is emitted after the statement unless
/// the statement is an `if` or a recursive call, both of which emit their
/// own return in that position.
///
/// # Errors
/// Returns `CodeGenError::Logic` for statement kinds that cannot be
/// specialized, and propagates errors from nested emission.
fn codegen_specialized_statement(
    &mut self,
    ctx: &mut RegisterContext,
    stmt: &Statement,
    word_name: &str,
    sig: &SpecSignature,
    is_last: bool,
    prev_int_literal: &mut Option<i64>,
) -> Result<(), CodeGenError> {
    // Consume the carried literal; it is only visible to this statement.
    let prev_int = prev_int_literal.take();

    match stmt {
        Statement::IntLiteral(n) => {
            let var = self.fresh_temp();
            writeln!(&mut self.output, " %{} = add i64 0, {}", var, n)?;
            ctx.push(var, RegisterType::I64);
            *prev_int_literal = Some(*n);
        }
        Statement::FloatLiteral(f) => {
            let var = self.fresh_temp();
            // Emit the exact bit pattern so no precision is lost in text form.
            writeln!(
                &mut self.output,
                " %{} = bitcast i64 {} to double",
                var,
                f.to_bits()
            )?;
            ctx.push(var, RegisterType::Double);
        }
        Statement::BoolLiteral(b) => {
            let var = self.fresh_temp();
            writeln!(
                &mut self.output,
                " %{} = add i64 0, {}",
                var,
                i64::from(*b)
            )?;
            ctx.push(var, RegisterType::I64);
        }
        Statement::WordCall { name, .. } => {
            self.codegen_specialized_word_call(ctx, name, word_name, sig, is_last, prev_int)?;
        }
        Statement::If {
            then_branch,
            else_branch,
            ..
        } => {
            self.codegen_specialized_if(
                ctx,
                then_branch,
                else_branch.as_ref(),
                word_name,
                sig,
                is_last,
            )?;
        }
        Statement::StringLiteral(_)
        | Statement::Symbol(_)
        | Statement::Quotation { .. }
        | Statement::Match { .. } => {
            return Err(CodeGenError::Logic(format!(
                "Non-specializable statement in specialized word: {:?}",
                stmt
            )));
        }
    }

    // `if` and recursive calls take care of their own return when last.
    let returns_on_its_own = matches!(stmt, Statement::If { .. })
        || matches!(stmt, Statement::WordCall { name, .. } if name == word_name);
    if is_last && !returns_on_its_own {
        self.emit_specialized_return(ctx, sig)?;
    }
    Ok(())
}
fn codegen_specialized_word_call(
&mut self,
ctx: &mut RegisterContext,
name: &str,
word_name: &str,
sig: &SpecSignature,
is_last: bool,
prev_int: Option<i64>,
) -> Result<(), CodeGenError> {
match name {
"dup" => ctx.dup(),
"drop" => ctx.drop(),
"swap" => ctx.swap(),
"over" => ctx.over(),
"rot" => ctx.rot(),
"nip" => {
ctx.swap();
ctx.drop();
}
"tuck" => {
ctx.dup();
let b = ctx.pop().unwrap();
let b2 = ctx.pop().unwrap();
let a = ctx.pop().unwrap();
ctx.push(b.0, b.1);
ctx.push(a.0, a.1);
ctx.push(b2.0, b2.1);
}
"pick" => {
let n = prev_int.ok_or_else(|| {
CodeGenError::Logic("pick requires constant N in specialized mode".to_string())
})?;
if n < 0 {
return Err(CodeGenError::Logic(format!(
"pick requires non-negative N, got {}",
n
)));
}
let n = n as usize;
ctx.pop();
let len = ctx.values.len();
if n >= len {
return Err(CodeGenError::Logic(format!(
"pick {} but only {} values in context",
n, len
)));
}
let (var, ty) = ctx.values[len - 1 - n].clone();
ctx.push(var, ty);
}
"roll" => {
let n = prev_int.ok_or_else(|| {
CodeGenError::Logic("roll requires constant N in specialized mode".to_string())
})?;
if n < 0 {
return Err(CodeGenError::Logic(format!(
"roll requires non-negative N, got {}",
n
)));
}
let n = n as usize;
ctx.pop();
let len = ctx.values.len();
if n >= len {
return Err(CodeGenError::Logic(format!(
"roll {} but only {} values in context",
n, len
)));
}
if n > 0 {
let val = ctx.values.remove(len - 1 - n);
ctx.values.push(val);
}
}
"i.+" | "i.add" => {
let (b, _) = ctx.pop().unwrap();
let (a, _) = ctx.pop().unwrap();
let result = self.fresh_temp();
writeln!(&mut self.output, " %{} = add i64 %{}, %{}", result, a, b)?;
ctx.push(result, RegisterType::I64);
}
"i.-" | "i.subtract" => {
let (b, _) = ctx.pop().unwrap();
let (a, _) = ctx.pop().unwrap();
let result = self.fresh_temp();
writeln!(&mut self.output, " %{} = sub i64 %{}, %{}", result, a, b)?;
ctx.push(result, RegisterType::I64);
}
"i.*" | "i.multiply" => {
let (b, _) = ctx.pop().unwrap();
let (a, _) = ctx.pop().unwrap();
let result = self.fresh_temp();
writeln!(&mut self.output, " %{} = mul i64 %{}, %{}", result, a, b)?;
ctx.push(result, RegisterType::I64);
}
"i./" | "i.divide" => {
self.emit_specialized_safe_div(ctx, "sdiv")?;
}
"i.%" | "i.mod" => {
self.emit_specialized_safe_div(ctx, "srem")?;
}
"band" => {
let (b, _) = ctx.pop().unwrap();
let (a, _) = ctx.pop().unwrap();
let result = self.fresh_temp();
writeln!(&mut self.output, " %{} = and i64 %{}, %{}", result, a, b)?;
ctx.push(result, RegisterType::I64);
}
"bor" => {
let (b, _) = ctx.pop().unwrap();
let (a, _) = ctx.pop().unwrap();
let result = self.fresh_temp();
writeln!(&mut self.output, " %{} = or i64 %{}, %{}", result, a, b)?;
ctx.push(result, RegisterType::I64);
}
"bxor" => {
let (b, _) = ctx.pop().unwrap();
let (a, _) = ctx.pop().unwrap();
let result = self.fresh_temp();
writeln!(&mut self.output, " %{} = xor i64 %{}, %{}", result, a, b)?;
ctx.push(result, RegisterType::I64);
}
"bnot" => {
let (a, _) = ctx.pop().unwrap();
let result = self.fresh_temp();
writeln!(&mut self.output, " %{} = xor i64 %{}, -1", result, a)?;
ctx.push(result, RegisterType::I64);
}
"shl" => {
self.emit_specialized_safe_shift(ctx, true)?;
}
"shr" => {
self.emit_specialized_safe_shift(ctx, false)?;
}
"popcount" => {
let (a, _) = ctx.pop().unwrap();
let result = self.fresh_temp();
writeln!(
&mut self.output,
" %{} = call i64 @llvm.ctpop.i64(i64 %{})",
result, a
)?;
ctx.push(result, RegisterType::I64);
}
"clz" => {
let (a, _) = ctx.pop().unwrap();
let result = self.fresh_temp();
writeln!(
&mut self.output,
" %{} = call i64 @llvm.ctlz.i64(i64 %{}, i1 false)",
result, a
)?;
ctx.push(result, RegisterType::I64);
}
"ctz" => {
let (a, _) = ctx.pop().unwrap();
let result = self.fresh_temp();
writeln!(
&mut self.output,
" %{} = call i64 @llvm.cttz.i64(i64 %{}, i1 false)",
result, a
)?;
ctx.push(result, RegisterType::I64);
}
"int->float" => {
let (a, _) = ctx.pop().unwrap();
let result = self.fresh_temp();
writeln!(
&mut self.output,
" %{} = sitofp i64 %{} to double",
result, a
)?;
ctx.push(result, RegisterType::Double);
}
"float->int" => {
let (a, _) = ctx.pop().unwrap();
let result = self.fresh_temp();
writeln!(
&mut self.output,
" %{} = fptosi double %{} to i64",
result, a
)?;
ctx.push(result, RegisterType::I64);
}
"and" => {
let (b, _) = ctx.pop().unwrap();
let (a, _) = ctx.pop().unwrap();
let result = self.fresh_temp();
writeln!(&mut self.output, " %{} = and i64 %{}, %{}", result, a, b)?;
ctx.push(result, RegisterType::I64);
}
"or" => {
let (b, _) = ctx.pop().unwrap();
let (a, _) = ctx.pop().unwrap();
let result = self.fresh_temp();
writeln!(&mut self.output, " %{} = or i64 %{}, %{}", result, a, b)?;
ctx.push(result, RegisterType::I64);
}
"not" => {
let (a, _) = ctx.pop().unwrap();
let result = self.fresh_temp();
writeln!(&mut self.output, " %{} = xor i64 %{}, 1", result, a)?;
ctx.push(result, RegisterType::I64);
}
"i.<" | "i.lt" => self.emit_specialized_icmp(ctx, "slt")?,
"i.>" | "i.gt" => self.emit_specialized_icmp(ctx, "sgt")?,
"i.<=" | "i.lte" => self.emit_specialized_icmp(ctx, "sle")?,
"i.>=" | "i.gte" => self.emit_specialized_icmp(ctx, "sge")?,
"i.=" | "i.eq" => self.emit_specialized_icmp(ctx, "eq")?,
"i.<>" | "i.neq" => self.emit_specialized_icmp(ctx, "ne")?,
"f.+" | "f.add" => {
let (b, _) = ctx.pop().unwrap();
let (a, _) = ctx.pop().unwrap();
let result = self.fresh_temp();
writeln!(
&mut self.output,
" %{} = fadd double %{}, %{}",
result, a, b
)?;
ctx.push(result, RegisterType::Double);
}
"f.-" | "f.subtract" => {
let (b, _) = ctx.pop().unwrap();
let (a, _) = ctx.pop().unwrap();
let result = self.fresh_temp();
writeln!(
&mut self.output,
" %{} = fsub double %{}, %{}",
result, a, b
)?;
ctx.push(result, RegisterType::Double);
}
"f.*" | "f.multiply" => {
let (b, _) = ctx.pop().unwrap();
let (a, _) = ctx.pop().unwrap();
let result = self.fresh_temp();
writeln!(
&mut self.output,
" %{} = fmul double %{}, %{}",
result, a, b
)?;
ctx.push(result, RegisterType::Double);
}
"f./" | "f.divide" => {
let (b, _) = ctx.pop().unwrap();
let (a, _) = ctx.pop().unwrap();
let result = self.fresh_temp();
writeln!(
&mut self.output,
" %{} = fdiv double %{}, %{}",
result, a, b
)?;
ctx.push(result, RegisterType::Double);
}
"f.<" | "f.lt" => self.emit_specialized_fcmp(ctx, "olt")?,
"f.>" | "f.gt" => self.emit_specialized_fcmp(ctx, "ogt")?,
"f.<=" | "f.lte" => self.emit_specialized_fcmp(ctx, "ole")?,
"f.>=" | "f.gte" => self.emit_specialized_fcmp(ctx, "oge")?,
"f.=" | "f.eq" => self.emit_specialized_fcmp(ctx, "oeq")?,
"f.<>" | "f.neq" => self.emit_specialized_fcmp(ctx, "one")?,
_ if name == word_name => {
self.emit_specialized_recursive_call(ctx, word_name, sig, is_last)?;
}
_ if self.specialized_words.contains_key(name) => {
self.emit_specialized_word_dispatch(ctx, name)?;
}
_ => {
return Err(CodeGenError::Logic(format!(
"Unhandled operation in specialized codegen: {}",
name
)));
}
}
Ok(())
}
/// Emits an integer comparison of the two topmost values.
///
/// The `i1` produced by `icmp` is zero-extended to `i64`, and that
/// widened value is what gets pushed.
///
/// # Panics
/// Panics if `ctx` holds fewer than two values.
fn emit_specialized_icmp(
    &mut self,
    ctx: &mut RegisterContext,
    cmp_op: &str,
) -> Result<(), CodeGenError> {
    let rhs = ctx.pop().unwrap().0;
    let lhs = ctx.pop().unwrap().0;
    let flag = self.fresh_temp();
    let widened = self.fresh_temp();
    writeln!(
        &mut self.output,
        " %{} = icmp {} i64 %{}, %{}\n %{} = zext i1 %{} to i64",
        flag, cmp_op, lhs, rhs, widened, flag
    )?;
    ctx.push(widened, RegisterType::I64);
    Ok(())
}
/// Emits a floating-point comparison of the two topmost values.
///
/// The `i1` produced by `fcmp` is zero-extended to `i64`, and that
/// widened value is what gets pushed.
///
/// # Panics
/// Panics if `ctx` holds fewer than two values.
fn emit_specialized_fcmp(
    &mut self,
    ctx: &mut RegisterContext,
    cmp_op: &str,
) -> Result<(), CodeGenError> {
    let rhs = ctx.pop().unwrap().0;
    let lhs = ctx.pop().unwrap().0;
    let flag = self.fresh_temp();
    let widened = self.fresh_temp();
    writeln!(
        &mut self.output,
        " %{} = fcmp {} double %{}, %{}\n %{} = zext i1 %{} to i64",
        flag, cmp_op, lhs, rhs, widened, flag
    )?;
    ctx.push(widened, RegisterType::I64);
    Ok(())
}
/// Emits a guarded signed division (`sdiv`) or remainder (`srem`).
///
/// Pops the divisor and then the dividend, and pushes two `i64` values:
/// the result followed by a success flag (1 on success, 0 on division by
/// zero, in which case the result is 0).
///
/// Both `sdiv` and `srem` have undefined behaviour in LLVM when the
/// dividend is `i64::MIN` and the divisor is `-1`, so that case is routed
/// around the instruction. It yields `i64::MIN` for `sdiv` and `0` for
/// `srem`, with the success flag set.
/// NOTE(review): these are two's-complement wrapping results; confirm
/// they match the non-specialized runtime.
///
/// # Panics
/// Panics if `ctx` holds fewer than two values.
fn emit_specialized_safe_div(
    &mut self,
    ctx: &mut RegisterContext,
    op: &str,
) -> Result<(), CodeGenError> {
    let (b, _) = ctx.pop().unwrap();
    let (a, _) = ctx.pop().unwrap();

    let is_zero = self.fresh_temp();
    writeln!(&mut self.output, " %{} = icmp eq i64 %{}, 0", is_zero, b)?;

    // Result substituted when the instruction itself would overflow.
    let overflow_value = match op {
        "sdiv" => Some("-9223372036854775808"),
        "srem" => Some("0"),
        _ => None,
    };

    // Compute the `a == INT_MIN && b == -1` flag in the entry block.
    let overflow_flag = match overflow_value {
        Some(value) => {
            let is_int_min = self.fresh_temp();
            let is_neg_one = self.fresh_temp();
            let is_overflow = self.fresh_temp();
            writeln!(
                &mut self.output,
                " %{} = icmp eq i64 %{}, -9223372036854775808",
                is_int_min, a
            )?;
            writeln!(
                &mut self.output,
                " %{} = icmp eq i64 %{}, -1",
                is_neg_one, b
            )?;
            writeln!(
                &mut self.output,
                " %{} = and i1 %{}, %{}",
                is_overflow, is_int_min, is_neg_one
            )?;
            Some((is_overflow, value))
        }
        None => None,
    };

    let ok_label = self.fresh_block("div_ok");
    let fail_label = self.fresh_block("div_fail");
    let merge_label = self.fresh_block("div_merge");
    let overflow = match overflow_flag {
        Some((flag, value)) => Some((flag, value, self.fresh_block("div_overflow"))),
        None => None,
    };

    // Zero divisor -> fail; otherwise check overflow (if guarded) or divide.
    let nonzero_target = overflow
        .as_ref()
        .map_or(&ok_label, |(_, _, label)| label);
    writeln!(
        &mut self.output,
        " br i1 %{}, label %{}, label %{}",
        is_zero, fail_label, nonzero_target
    )?;

    if let Some((flag, _, label)) = &overflow {
        writeln!(&mut self.output, "{}:", label)?;
        writeln!(
            &mut self.output,
            " br i1 %{}, label %{}, label %{}",
            flag, merge_label, ok_label
        )?;
    }

    writeln!(&mut self.output, "{}:", ok_label)?;
    let ok_result = self.fresh_temp();
    writeln!(
        &mut self.output,
        " %{} = {} i64 %{}, %{}",
        ok_result, op, a, b
    )?;
    writeln!(&mut self.output, " br label %{}", merge_label)?;

    writeln!(&mut self.output, "{}:", fail_label)?;
    writeln!(&mut self.output, " br label %{}", merge_label)?;

    writeln!(&mut self.output, "{}:", merge_label)?;
    let result_phi = self.fresh_temp();
    let success_phi = self.fresh_temp();
    match &overflow {
        Some((_, value, label)) => {
            writeln!(
                &mut self.output,
                " %{} = phi i64 [ %{}, %{} ], [ 0, %{} ], [ {}, %{} ]",
                result_phi, ok_result, ok_label, fail_label, value, label
            )?;
            writeln!(
                &mut self.output,
                " %{} = phi i64 [ 1, %{} ], [ 0, %{} ], [ 1, %{} ]",
                success_phi, ok_label, fail_label, label
            )?;
        }
        None => {
            writeln!(
                &mut self.output,
                " %{} = phi i64 [ %{}, %{} ], [ 0, %{} ]",
                result_phi, ok_result, ok_label, fail_label
            )?;
            writeln!(
                &mut self.output,
                " %{} = phi i64 [ 1, %{} ], [ 0, %{} ]",
                success_phi, ok_label, fail_label
            )?;
        }
    }

    ctx.push(result_phi, RegisterType::I64);
    ctx.push(success_phi, RegisterType::I64);
    Ok(())
}
/// Emits a guarded shift of the second-from-top value by the top value.
///
/// A shift count that is negative or at least 64 produces 0. The count
/// is also replaced by 0 before the shift instruction is executed, so
/// the instruction never sees an out-of-range amount. Left shifts use
/// `shl`; right shifts use the logical `lshr`.
///
/// # Panics
/// Panics if `ctx` holds fewer than two values.
fn emit_specialized_safe_shift(
    &mut self,
    ctx: &mut RegisterContext,
    is_left: bool,
) -> Result<(), CodeGenError> {
    let count = ctx.pop().unwrap().0;
    let value = ctx.pop().unwrap().0;

    let below_zero = self.fresh_temp();
    writeln!(
        &mut self.output,
        " %{} = icmp slt i64 %{}, 0",
        below_zero, count
    )?;

    let at_least_width = self.fresh_temp();
    writeln!(
        &mut self.output,
        " %{} = icmp sge i64 %{}, 64",
        at_least_width, count
    )?;

    let out_of_range = self.fresh_temp();
    writeln!(
        &mut self.output,
        " %{} = or i1 %{}, %{}",
        out_of_range, below_zero, at_least_width
    )?;

    // Clamp the count so the shift instruction itself is always well-defined.
    let clamped = self.fresh_temp();
    writeln!(
        &mut self.output,
        " %{} = select i1 %{}, i64 0, i64 %{}",
        clamped, out_of_range, count
    )?;

    let shifted = self.fresh_temp();
    let instr = match is_left {
        true => "shl",
        false => "lshr",
    };
    writeln!(
        &mut self.output,
        " %{} = {} i64 %{}, %{}",
        shifted, instr, value, clamped
    )?;

    let outcome = self.fresh_temp();
    writeln!(
        &mut self.output,
        " %{} = select i1 %{}, i64 0, i64 %{}",
        outcome, out_of_range, shifted
    )?;

    ctx.push(outcome, RegisterType::I64);
    Ok(())
}
/// Emits a call from a specialized word to itself.
///
/// The topmost `sig.inputs.len()` values are consumed as arguments, in
/// stack order. In tail position the call is emitted as `musttail` and
/// followed directly by `ret`. Otherwise the result is pushed back, or
/// for multi-output signatures each struct field is extracted and pushed.
///
/// # Errors
/// Returns `CodeGenError::Logic` when `ctx` holds fewer values than the
/// signature has inputs.
fn emit_specialized_recursive_call(
    &mut self,
    ctx: &mut RegisterContext,
    word_name: &str,
    sig: &SpecSignature,
    is_tail: bool,
) -> Result<(), CodeGenError> {
    let spec_name = format!("seq_{}{}", mangle_name(word_name), sig.suffix());

    let needed = sig.inputs.len();
    let available = ctx.values.len();
    if available < needed {
        return Err(CodeGenError::Logic(format!(
            "Not enough values in context for recursive call to {}: need {}, have {}",
            word_name, needed, available
        )));
    }

    // Detach the arguments from the top of the stack, keeping their order.
    let args = ctx.values.split_off(available - needed);
    let arg_list = args
        .iter()
        .map(|(var, ty)| format!("{} %{}", ty.llvm_type(), var))
        .collect::<Vec<_>>()
        .join(", ");

    let return_type = sig.llvm_return_type();
    let result = self.fresh_temp();

    if is_tail {
        writeln!(
            &mut self.output,
            " %{} = musttail call {} @{}({})",
            result, return_type, spec_name, arg_list
        )?;
        writeln!(&mut self.output, " ret {} %{}", return_type, result)?;
        return Ok(());
    }

    writeln!(
        &mut self.output,
        " %{} = call {} @{}({})",
        result, return_type, spec_name, arg_list
    )?;

    match sig.outputs.as_slice() {
        [single] => ctx.push(result, *single),
        fields => {
            for (i, out_ty) in fields.iter().enumerate() {
                let extracted = self.fresh_temp();
                writeln!(
                    &mut self.output,
                    " %{} = extractvalue {} %{}, {}",
                    extracted, return_type, result, i
                )?;
                ctx.push(extracted, *out_ty);
            }
        }
    }
    Ok(())
}
/// Emits a call to another word that already has a specialized version.
///
/// The callee's signature is looked up in `specialized_words`. The
/// topmost `inputs.len()` values are consumed as arguments, in stack
/// order, and the results are pushed back. Multi-output callees return a
/// struct whose fields are extracted one by one.
///
/// # Errors
/// Returns `CodeGenError::Logic` when `name` has no specialized version,
/// or when `ctx` holds fewer values than the callee takes. The second
/// case previously panicked.
fn emit_specialized_word_dispatch(
    &mut self,
    ctx: &mut RegisterContext,
    name: &str,
) -> Result<(), CodeGenError> {
    let sig = self
        .specialized_words
        .get(name)
        .ok_or_else(|| CodeGenError::Logic(format!("Unknown specialized word: {}", name)))?
        .clone();
    let spec_name = format!("seq_{}{}", mangle_name(name), sig.suffix());

    // Same guard as the recursive-call path: report underflow, don't panic.
    let needed = sig.inputs.len();
    let available = ctx.values.len();
    if available < needed {
        return Err(CodeGenError::Logic(format!(
            "Not enough values in context for call to {}: need {}, have {}",
            name, needed, available
        )));
    }

    // Detach the arguments from the top of the stack, keeping their order.
    let args = ctx.values.split_off(available - needed);
    let arg_strs: Vec<String> = args
        .iter()
        .map(|(var, ty)| format!("{} %{}", ty.llvm_type(), var))
        .collect();

    let return_type = sig.llvm_return_type();
    let result = self.fresh_temp();
    writeln!(
        &mut self.output,
        " %{} = call {} @{}({})",
        result,
        return_type,
        spec_name,
        arg_strs.join(", ")
    )?;

    if sig.outputs.len() == 1 {
        ctx.push(result, sig.outputs[0]);
    } else {
        for (i, out_ty) in sig.outputs.iter().enumerate() {
            let extracted = self.fresh_temp();
            writeln!(
                &mut self.output,
                " %{} = extractvalue {} %{}, {}",
                extracted, return_type, result, i
            )?;
            ctx.push(extracted, *out_ty);
        }
    }
    Ok(())
}
/// Emits the `ret` instruction for a specialized word.
///
/// With no outputs this is `ret void`. With one output the top of `ctx`
/// is returned. With several, the topmost `outputs.len()` values are
/// packed into a struct through a chain of `insertvalue` instructions
/// starting from `undef`.
///
/// # Errors
/// Returns `CodeGenError::Logic` when `ctx` holds fewer values than the
/// signature needs.
fn emit_specialized_return(
    &mut self,
    ctx: &RegisterContext,
    sig: &SpecSignature,
) -> Result<(), CodeGenError> {
    match sig.outputs.len() {
        0 => writeln!(&mut self.output, " ret void")?,
        1 => {
            let Some((var, ty)) = ctx.values.last() else {
                return Err(CodeGenError::Logic("Empty context at return".to_string()));
            };
            writeln!(&mut self.output, " ret {} %{}", ty.llvm_type(), var)?;
        }
        output_count => {
            let available = ctx.values.len();
            if available < output_count {
                return Err(CodeGenError::Logic(format!(
                    "Not enough values for multi-output return: need {}, have {}",
                    output_count, available
                )));
            }

            let struct_type = sig.llvm_return_type();
            let returned = &ctx.values[available - output_count..];

            // Each insertvalue builds on the aggregate produced by the previous one.
            let mut aggregate = String::from("undef");
            for (field, (var, ty)) in returned.iter().enumerate() {
                let next = self.fresh_temp();
                writeln!(
                    &mut self.output,
                    " %{} = insertvalue {} {}, {} %{}, {}",
                    next,
                    struct_type,
                    aggregate,
                    ty.llvm_type(),
                    var,
                    field
                )?;
                aggregate = format!("%{}", next);
            }
            writeln!(&mut self.output, " ret {} {}", struct_type, aggregate)?;
        }
    }
    Ok(())
}
/// Emits an `if`/`else` inside a specialized word.
///
/// The condition is popped from `ctx` and tested against zero. Each
/// branch is emitted on its own copy of the context.
///
/// * In return position (`is_last`) every branch ends in its own return,
///   so no merge block is emitted.
/// * Otherwise both branches jump to a merge block. Positions whose SSA
///   names differ between the branches are joined with `phi` nodes, and
///   `ctx` is replaced by the merged stack.
///
/// NOTE(review): the phi nodes name the branch entry labels as
/// predecessors. If a branch body emits further blocks (a nested `if`,
/// or a guarded division), the real predecessor looks like it would be
/// that inner block instead — confirm and track the current block if so.
///
/// # Errors
/// Returns `CodeGenError::Logic` when `ctx` is empty at the condition, or
/// when the branches leave stacks of different depth or with different
/// register classes.
fn codegen_specialized_if(
    &mut self,
    ctx: &mut RegisterContext,
    then_branch: &[Statement],
    else_branch: Option<&Vec<Statement>>,
    word_name: &str,
    sig: &SpecSignature,
    is_last: bool,
) -> Result<(), CodeGenError> {
    let Some((cond_var, _)) = ctx.pop() else {
        return Err(CodeGenError::Logic(
            "Empty context at if condition".to_string(),
        ));
    };

    let cmp_result = self.fresh_temp();
    writeln!(
        &mut self.output,
        " %{} = icmp ne i64 %{}, 0",
        cmp_result, cond_var
    )?;

    let then_label = self.fresh_block("if_then");
    let else_label = self.fresh_block("if_else");
    let merge_label = self.fresh_block("if_merge");
    writeln!(
        &mut self.output,
        " br i1 %{}, label %{}, label %{}",
        cmp_result, then_label, else_label
    )?;

    // ---- then branch ----
    writeln!(&mut self.output, "{}:", then_label)?;
    let mut then_ctx = ctx.clone();
    let mut then_prev_int: Option<i64> = None;
    let then_count = then_branch.len();
    for (i, stmt) in then_branch.iter().enumerate() {
        self.codegen_specialized_statement(
            &mut then_ctx,
            stmt,
            word_name,
            sig,
            is_last && i + 1 == then_count,
            &mut then_prev_int,
        )?;
    }
    if is_last {
        // A non-empty branch returned through its final statement.
        if then_branch.is_empty() {
            self.emit_specialized_return(&then_ctx, sig)?;
        }
    } else {
        writeln!(&mut self.output, " br label %{}", merge_label)?;
    }

    // ---- else branch (a missing else behaves like an empty one) ----
    writeln!(&mut self.output, "{}:", else_label)?;
    let mut else_ctx = ctx.clone();
    let mut else_prev_int: Option<i64> = None;
    let else_stmts: &[Statement] = match else_branch {
        Some(stmts) => stmts.as_slice(),
        None => &[],
    };
    let else_count = else_stmts.len();
    for (i, stmt) in else_stmts.iter().enumerate() {
        self.codegen_specialized_statement(
            &mut else_ctx,
            stmt,
            word_name,
            sig,
            is_last && i + 1 == else_count,
            &mut else_prev_int,
        )?;
    }
    if is_last {
        if else_stmts.is_empty() {
            self.emit_specialized_return(&else_ctx, sig)?;
        }
        // Both branches have returned; there is nothing to merge.
        return Ok(());
    }
    writeln!(&mut self.output, " br label %{}", merge_label)?;

    // ---- merge ----
    writeln!(&mut self.output, "{}:", merge_label)?;
    if then_ctx.values.len() != else_ctx.values.len() {
        return Err(CodeGenError::Logic(format!(
            "Stack depth mismatch in if branches: then has {}, else has {}",
            then_ctx.values.len(),
            else_ctx.values.len()
        )));
    }

    ctx.values.clear();
    let paired = then_ctx.values.iter().zip(&else_ctx.values);
    for (i, ((then_var, then_ty), (else_var, else_ty))) in paired.enumerate() {
        if then_ty != else_ty {
            return Err(CodeGenError::Logic(format!(
                "Type mismatch at position {} in if branches: {:?} vs {:?}",
                i, then_ty, else_ty
            )));
        }
        if then_var == else_var {
            // Untouched by either branch: no phi needed.
            ctx.push(then_var.clone(), *then_ty);
        } else {
            let phi_result = self.fresh_temp();
            writeln!(
                &mut self.output,
                " %{} = phi {} [ %{}, %{} ], [ %{}, %{} ]",
                phi_result,
                then_ty.llvm_type(),
                then_var,
                then_label,
                else_var,
                else_label
            )?;
            ctx.push(phi_result, *then_ty);
        }
    }
    Ok(())
}
}
// Unit tests for the plain data helpers; code emission is not exercised here.
#[cfg(test)]
mod tests {
use super::*;
// Int and Bool share the integer register class, Float maps to Double,
// and String has no register representation.
#[test]
fn test_register_type_from_type() {
assert_eq!(RegisterType::from_type(&Type::Int), Some(RegisterType::I64));
assert_eq!(
RegisterType::from_type(&Type::Bool),
Some(RegisterType::I64)
);
assert_eq!(
RegisterType::from_type(&Type::Float),
Some(RegisterType::Double)
);
assert_eq!(RegisterType::from_type(&Type::String), None);
}
// One-in/one-out signatures of a single register class use the short suffix form.
#[test]
fn test_spec_signature_suffix() {
let sig = SpecSignature {
inputs: vec![RegisterType::I64],
outputs: vec![RegisterType::I64],
};
assert_eq!(sig.suffix(), "_i64");
let sig2 = SpecSignature {
inputs: vec![RegisterType::Double],
outputs: vec![RegisterType::Double],
};
assert_eq!(sig2.suffix(), "_f64");
}
// push / swap / dup / drop only rearrange the tracked SSA names.
#[test]
fn test_register_context_stack_ops() {
let mut ctx = RegisterContext::new();
ctx.push("a".to_string(), RegisterType::I64);
ctx.push("b".to_string(), RegisterType::I64);
assert_eq!(ctx.len(), 2);
// After the swap the stack reads [b, a] from bottom to top.
ctx.swap();
assert_eq!(ctx.values[0].0, "b");
assert_eq!(ctx.values[1].0, "a");
// dup copies the top entry ("a").
ctx.dup();
assert_eq!(ctx.len(), 3);
assert_eq!(ctx.values[2].0, "a");
ctx.drop();
assert_eq!(ctx.len(), 2);
}
}