use std::collections::HashMap;
use std::fs;
use std::path::Path;
use wasm_encoder::{
CodeSection, ExportKind, ExportSection, Function, FunctionSection, Ieee64, Instruction, Module, TypeSection, ValType,
};
use crate::builtins::syntax::get_raw_args_fn;
use crate::calcit::{Calcit, CalcitArgLabel, CalcitFnArgs, CalcitLocal, CalcitProc, CalcitSyntax};
use crate::program;
fn f64_const(v: f64) -> Instruction<'static> {
Instruction::F64Const(Ieee64::from(v))
}
/// Compiles the function definitions of namespace `init_ns` into a single
/// WebAssembly module and writes it to `<emit_path>/program.wasm`.
///
/// Definitions whose kind is not `Fn`, or whose preprocessed code is not a
/// `defn` form, are skipped with a message on stderr. A function that fails to
/// compile is replaced by a stub returning `0.0`, so the indices already handed
/// out to callers stay valid.
///
/// Functions are sorted by name before indices are assigned, so the emitted
/// module does not depend on the iteration order of the namespace's
/// definitions.
/// NOTE(review): `defs` is presumably a hash-based map; if it is already
/// ordered, the sort only changes index order and is otherwise harmless.
///
/// # Errors
///
/// Returns an error when the program snapshot cannot be obtained, when
/// `init_ns` is not part of the program, when no function definition was
/// found, or when the output directory or file cannot be written.
pub fn emit_wasm(init_ns: &str, emit_path: &str) -> Result<(), String> {
    let program_data = program::clone_compiled_program_snapshot()?;
    let Some(file_info) = program_data.get(init_ns) else {
        return Err(format!("namespace not found: {init_ns}"));
    };

    // Pass 1: collect every definition that has the shape of a function.
    let mut fn_defs: Vec<(String, CalcitFnArgs, Vec<Calcit>)> = Vec::new();
    for (def_name, compiled) in &file_info.defs {
        if compiled.kind != program::CompiledDefKind::Fn {
            continue;
        }
        match extract_fn_parts(&compiled.preprocessed_code) {
            Ok((args, body)) => fn_defs.push((def_name.to_string(), args, body)),
            Err(e) => eprintln!("[wasm] skipping {init_ns}/{def_name}: {e}"),
        }
    }
    // Fix the order before indices are assigned so builds are reproducible.
    fn_defs.sort_by(|a, b| a.0.cmp(&b.0));

    // Indices must be known up front so functions can call each other
    // regardless of definition order.
    let fn_index: HashMap<String, u32> = fn_defs
        .iter()
        .enumerate()
        .map(|(i, (name, _, _))| (name.clone(), i as u32))
        .collect();

    // Pass 2: compile each function, keeping one entry per index.
    let mut compiled_fns: Vec<CompiledFn> = Vec::with_capacity(fn_defs.len());
    for (def_name, args, body) in &fn_defs {
        let func = match compile_fn(def_name, args, body, &fn_index) {
            Ok(func) => func,
            Err(e) => {
                eprintln!("[wasm] skipping {init_ns}/{def_name}: {e}");
                // The stub keeps the length of the declared argument list.
                let arity = match args {
                    CalcitFnArgs::Args(v) => v.len(),
                    CalcitFnArgs::MarkedArgs(v) => v.len(),
                };
                CompiledFn {
                    name: def_name.clone(),
                    arity,
                    locals: vec![],
                    instructions: vec![f64_const(0.0)],
                }
            }
        };
        compiled_fns.push(func);
    }
    if compiled_fns.is_empty() {
        return Err("no functions could be compiled to WASM".into());
    }

    let wasm_bytes = build_wasm_module(&compiled_fns)?;

    // `create_dir_all` succeeds when the directory already exists, so no
    // separate (and racy) existence check is needed.
    let out_path = Path::new(emit_path);
    fs::create_dir_all(out_path).map_err(|e| format!("failed to create dir: {e}"))?;
    let wasm_file = out_path.join("program.wasm");
    fs::write(&wasm_file, &wasm_bytes).map_err(|e| format!("failed to write WASM: {e}"))?;
    println!("wrote WASM to: {}", wasm_file.display());
    Ok(())
}
/// One function lowered to WebAssembly, ready to be placed into a module.
struct CompiledFn {
/// Name under which the function is exported.
name: String,
/// Number of parameters; the module builder types each one as `f64`.
arity: usize,
/// Types of the locals declared in addition to the parameters.
locals: Vec<ValType>,
/// Function body without the closing `end`, which the module builder appends.
instructions: Vec<Instruction<'static>>,
}
/// Assembles the compiled functions into the bytes of a WebAssembly module.
///
/// Each function gets its own type entry (`arity` `f64` parameters, one `f64`
/// result), is exported under its name, and has its body terminated with
/// `end`. Sections are added in the order type, function, export, code.
fn build_wasm_module(fns: &[CompiledFn]) -> Result<Vec<u8>, String> {
    let mut types = TypeSection::new();
    let mut functions = FunctionSection::new();
    let mut exports = ExportSection::new();
    let mut codes = CodeSection::new();

    for (position, compiled) in fns.iter().enumerate() {
        let index = position as u32;

        types
            .ty()
            .function(vec![ValType::F64; compiled.arity], vec![ValType::F64]);
        functions.function(index);
        exports.export(&compiled.name, ExportKind::Func, index);

        // Collapse consecutive locals of the same type into (count, type) runs.
        let mut runs: Vec<(u32, ValType)> = Vec::new();
        for &ty in &compiled.locals {
            if let Some((count, last)) = runs.last_mut() {
                if *last == ty {
                    *count += 1;
                    continue;
                }
            }
            runs.push((1, ty));
        }

        let mut func = Function::new(runs);
        for instr in &compiled.instructions {
            func.instruction(instr);
        }
        func.instruction(&Instruction::End);
        codes.function(&func);
    }

    let mut module = Module::new();
    module.section(&types);
    module.section(&functions);
    module.section(&exports);
    module.section(&codes);
    Ok(module.finish())
}
/// Splits a preprocessed `defn` form into its argument description and body.
///
/// The form must be a list whose first three items are the `defn` syntax, a
/// symbol and the argument list; everything after those is returned as the
/// body.
///
/// # Errors
///
/// Returns an error when `code` is not such a list, or when the argument list
/// is rejected by `get_raw_args_fn`.
fn extract_fn_parts(code: &Calcit) -> Result<(CalcitFnArgs, Vec<Calcit>), String> {
    let Calcit::List(items) = code else {
        return Err(format!("expected preprocessed defn list, got: {code}"));
    };
    let (Some(Calcit::Syntax(CalcitSyntax::Defn, _)), Some(Calcit::Symbol { .. }), Some(Calcit::List(args))) =
        (items.first(), items.get(1), items.get(2))
    else {
        return Err(format!("expected preprocessed defn form, got: {code}"));
    };
    let raw_args = get_raw_args_fn(args)?;
    // Drop the `defn` head, the name and the argument list.
    let body = items.drop_left().drop_left().drop_left().to_vec();
    Ok((raw_args, body))
}
/// Mutable state carried while lowering a single function body.
struct WasmGenCtx {
/// Maps a variable name to its WebAssembly local index.
locals: HashMap<String, u32>,
/// Types of the locals allocated beyond the parameters, in allocation order.
extra_locals: Vec<ValType>,
/// Index the next allocated local will receive; starts at the parameter count.
next_local: u32,
/// Whether the body contains `recur`; when set, the body is wrapped in a `loop`.
uses_recur: bool,
/// Local indices of the parameters, in order; `recur` stores its new values here.
arg_indices: Vec<u32>,
/// Instructions emitted so far.
instructions: Vec<Instruction<'static>>,
/// Maps a function name to its function index, used for `call`.
fn_index: HashMap<String, u32>,
/// Number of `if` blocks currently open; used as the `br` label depth by `recur`.
block_depth: u32,
}
impl WasmGenCtx {
    /// Creates an empty context for a function with `num_params` parameters.
    ///
    /// Parameters occupy local indices `0..num_params`, so freshly allocated
    /// locals start right after them.
    fn new(num_params: u32, fn_index: HashMap<String, u32>) -> Self {
        Self {
            fn_index,
            next_local: num_params,
            block_depth: 0,
            uses_recur: false,
            locals: HashMap::new(),
            arg_indices: Vec::new(),
            extra_locals: Vec::new(),
            instructions: Vec::new(),
        }
    }

    /// Reserves a new anonymous `f64` local and returns its index.
    fn alloc_local(&mut self) -> u32 {
        self.extra_locals.push(ValType::F64);
        let slot = self.next_local;
        self.next_local += 1;
        slot
    }

    /// Reserves a new `f64` local and binds `name` to it, replacing any
    /// earlier binding of the same name.
    fn declare_local(&mut self, name: &str) -> u32 {
        let slot = self.alloc_local();
        self.locals.insert(name.to_string(), slot);
        slot
    }

    /// Appends `instr` to the instruction stream.
    fn emit(&mut self, instr: Instruction<'static>) {
        self.instructions.push(instr);
    }
}
fn compile_fn(name: &str, args: &CalcitFnArgs, body: &[Calcit], fn_index: &HashMap<String, u32>) -> Result<CompiledFn, String> {
let mut param_names = Vec::new();
match args {
CalcitFnArgs::Args(idxs) => {
for idx in idxs {
param_names.push(CalcitLocal::read_name(*idx));
}
}
CalcitFnArgs::MarkedArgs(labels) => {
for label in labels {
match label {
CalcitArgLabel::Idx(idx) => {
param_names.push(CalcitLocal::read_name(*idx));
}
CalcitArgLabel::OptionalMark | CalcitArgLabel::RestMark => {
return Err("optional/rest args not supported in WASM codegen".into());
}
}
}
}
}
let arity = param_names.len();
let mut ctx = WasmGenCtx::new(arity as u32, fn_index.clone());
for (i, pname) in param_names.iter().enumerate() {
ctx.locals.insert(pname.clone(), i as u32);
ctx.arg_indices.push(i as u32);
}
ctx.uses_recur = body.iter().any(check_uses_recur);
if ctx.uses_recur {
ctx.emit(Instruction::Loop(wasm_encoder::BlockType::Result(ValType::F64)));
emit_body(&mut ctx, body)?;
ctx.emit(Instruction::End); } else {
emit_body(&mut ctx, body)?;
}
Ok(CompiledFn {
name: name.to_owned(),
arity,
locals: ctx.extra_locals,
instructions: ctx.instructions,
})
}
/// Reports whether `expr` contains a `recur` proc anywhere inside it.
///
/// Lists headed by the `defn` syntax are not searched.
fn check_uses_recur(expr: &Calcit) -> bool {
    if matches!(expr, Calcit::Proc(CalcitProc::Recur)) {
        return true;
    }
    let Calcit::List(xs) = expr else {
        return false;
    };
    let is_nested_defn = matches!(xs.first(), Some(Calcit::Syntax(CalcitSyntax::Defn, _)));
    !is_nested_defn && xs.iter().any(check_uses_recur)
}
/// Emits a sequence of expressions, leaving only the value of the last one on
/// the stack. An empty sequence yields `0.0`.
fn emit_body(ctx: &mut WasmGenCtx, exprs: &[Calcit]) -> Result<(), String> {
    let Some((last, leading)) = exprs.split_last() else {
        ctx.emit(f64_const(0.0));
        return Ok(());
    };
    for expr in leading {
        emit_expr(ctx, expr)?;
        // Only the final value is kept.
        ctx.emit(Instruction::Drop);
    }
    emit_expr(ctx, last)
}
/// Emits code that pushes the value of `expr` as an `f64`.
///
/// Numbers are pushed as-is, `true` as `1.0`, `false` and `nil` as `0.0`;
/// locals are read from their slot and non-empty lists are lowered as calls.
///
/// # Errors
///
/// Returns an error for an unknown local or any other kind of expression.
fn emit_expr(ctx: &mut WasmGenCtx, expr: &Calcit) -> Result<(), String> {
    let instr = match expr {
        Calcit::Number(n) => f64_const(*n),
        Calcit::Bool(true) => f64_const(1.0),
        Calcit::Bool(false) | Calcit::Nil => f64_const(0.0),
        Calcit::Local(local) => {
            let name = &*local.sym;
            match ctx.locals.get(name) {
                Some(idx) => Instruction::LocalGet(*idx),
                None => return Err(format!("undefined local variable: {name}")),
            }
        }
        Calcit::List(xs) if !xs.is_empty() => return emit_call_expr(ctx, xs),
        _ => return Err(format!("unsupported WASM expression: {expr}")),
    };
    ctx.emit(instr);
    Ok(())
}
/// Lowers a non-empty list expression according to its head.
///
/// `if` and `&let` syntax and procs are delegated to their emitters. An
/// import or symbol head is looked up by definition name in the function
/// index and becomes a `call` after its arguments are pushed in order.
///
/// # Errors
///
/// Returns an error for unsupported syntax, an unknown function name, or any
/// other kind of head.
fn emit_call_expr(ctx: &mut WasmGenCtx, xs: &crate::calcit::CalcitList) -> Result<(), String> {
    let head = &xs[0];
    let args_list: Vec<Calcit> = xs.drop_left().to_vec();

    // Everything except a plain function call is dispatched right away.
    let callee: &str = match head {
        Calcit::Syntax(CalcitSyntax::If, _) => return emit_if(ctx, &args_list),
        Calcit::Syntax(CalcitSyntax::CoreLet, _) => return emit_let(ctx, &args_list),
        Calcit::Syntax(CalcitSyntax::Defn, _) => return Err("nested defn not supported in WASM".into()),
        Calcit::Syntax(syn, _) => return Err(format!("unsupported syntax in WASM: {syn}")),
        Calcit::Proc(proc) => return emit_proc_call(ctx, proc, &args_list),
        Calcit::Import(import) => import.def.as_ref(),
        Calcit::Symbol { sym, .. } => sym.as_ref(),
        _ => return Err(format!("unsupported call head in WASM: {head}")),
    };

    let fn_idx = match ctx.fn_index.get(callee) {
        Some(idx) => *idx,
        None => return Err(format!("unknown function: {callee}")),
    };
    for arg in &args_list {
        emit_expr(ctx, arg)?;
    }
    ctx.emit(Instruction::Call(fn_idx));
    Ok(())
}
fn emit_proc_call(ctx: &mut WasmGenCtx, proc: &CalcitProc, args: &[Calcit]) -> Result<(), String> {
match proc {
CalcitProc::NativeAdd => emit_binary(ctx, Instruction::F64Add, args),
CalcitProc::NativeMinus => emit_binary(ctx, Instruction::F64Sub, args),
CalcitProc::NativeMultiply => emit_binary(ctx, Instruction::F64Mul, args),
CalcitProc::NativeDivide => emit_binary(ctx, Instruction::F64Div, args),
CalcitProc::NativeNumberRem => {
if args.len() != 2 {
return Err("rem expects 2 args".into());
}
emit_expr(ctx, &args[0])?; emit_expr(ctx, &args[0])?; emit_expr(ctx, &args[1])?; ctx.emit(Instruction::F64Div);
ctx.emit(Instruction::F64Trunc);
emit_expr(ctx, &args[1])?; ctx.emit(Instruction::F64Mul);
ctx.emit(Instruction::F64Sub);
Ok(())
}
CalcitProc::NativeLessThan => emit_cmp(ctx, Instruction::F64Lt, args),
CalcitProc::NativeGreaterThan => emit_cmp(ctx, Instruction::F64Gt, args),
CalcitProc::NativeEquals | CalcitProc::Identical => emit_cmp(ctx, Instruction::F64Eq, args),
CalcitProc::Not => {
if args.len() != 1 {
return Err("not expects 1 arg".into());
}
ctx.emit(f64_const(1.0)); ctx.emit(f64_const(0.0)); emit_expr(ctx, &args[0])?;
ctx.emit(f64_const(0.0));
ctx.emit(Instruction::F64Eq); ctx.emit(Instruction::Select);
Ok(())
}
CalcitProc::Floor => emit_unary(ctx, Instruction::F64Floor, args),
CalcitProc::Ceil => emit_unary(ctx, Instruction::F64Ceil, args),
CalcitProc::Round => emit_unary(ctx, Instruction::F64Nearest, args),
CalcitProc::Sqrt => emit_unary(ctx, Instruction::F64Sqrt, args),
CalcitProc::Sin | CalcitProc::Cos => Err(format!("trigonometric function {proc} not available in WASM (no f64.sin/cos)")),
CalcitProc::Pow => Err("pow not yet supported in WASM codegen (no f64.pow instruction)".into()),
CalcitProc::Recur => {
if args.len() != ctx.arg_indices.len() {
return Err(format!(
"recur arity mismatch: expected {}, got {}",
ctx.arg_indices.len(),
args.len()
));
}
let mut temps = Vec::new();
for arg in args {
let tmp = ctx.alloc_local();
emit_expr(ctx, arg)?;
ctx.emit(Instruction::LocalSet(tmp));
temps.push(tmp);
}
for (i, &tmp) in temps.iter().enumerate() {
ctx.emit(Instruction::LocalGet(tmp));
ctx.emit(Instruction::LocalSet(ctx.arg_indices[i]));
}
ctx.emit(Instruction::Br(ctx.block_depth)); ctx.emit(Instruction::Unreachable);
Ok(())
}
_ => Err(format!("unsupported proc in WASM: {proc}")),
}
}
/// Emits the single operand followed by the unary instruction `instr`.
///
/// # Errors
///
/// Returns an error unless exactly one argument is given.
fn emit_unary(ctx: &mut WasmGenCtx, instr: Instruction<'static>, args: &[Calcit]) -> Result<(), String> {
    let [operand] = args else {
        return Err(format!("{instr:?} expects 1 arg, got {}", args.len()));
    };
    emit_expr(ctx, operand)?;
    ctx.emit(instr);
    Ok(())
}
/// Emits both operands, left first, followed by the binary instruction
/// `instr`.
///
/// # Errors
///
/// Returns an error unless exactly two arguments are given.
fn emit_binary(ctx: &mut WasmGenCtx, instr: Instruction<'static>, args: &[Calcit]) -> Result<(), String> {
    let [lhs, rhs] = args else {
        return Err(format!("{instr:?} expects 2 args, got {}", args.len()));
    };
    emit_expr(ctx, lhs)?;
    emit_expr(ctx, rhs)?;
    ctx.emit(instr);
    Ok(())
}
/// Emits a comparison whose result is an `f64`: `1.0` when the comparison
/// `instr` holds for the two operands, `0.0` otherwise.
///
/// # Errors
///
/// Returns an error unless exactly two arguments are given.
fn emit_cmp(ctx: &mut WasmGenCtx, instr: Instruction<'static>, args: &[Calcit]) -> Result<(), String> {
    let [lhs, rhs] = args else {
        return Err(format!("{instr:?} expects 2 args, got {}", args.len()));
    };
    // The two candidate results go first; `select` picks one by the condition.
    for candidate in [1.0, 0.0] {
        ctx.emit(f64_const(candidate));
    }
    emit_expr(ctx, lhs)?;
    emit_expr(ctx, rhs)?;
    ctx.emit(instr);
    ctx.emit(Instruction::Select);
    Ok(())
}
/// Lowers `if` to an `if`/`else` block yielding `f64`.
///
/// The condition is true when it differs from `0.0`. A missing else branch
/// yields `0.0`. `block_depth` is raised while the branches are emitted so
/// that `recur` inside them branches to the right label.
///
/// # Errors
///
/// Returns an error unless two or three arguments are given.
fn emit_if(ctx: &mut WasmGenCtx, args: &[Calcit]) -> Result<(), String> {
    let (cond, then_branch, else_branch) = match args {
        [cond, then_branch] => (cond, then_branch, None),
        [cond, then_branch, else_branch] => (cond, then_branch, Some(else_branch)),
        _ => return Err(format!("if expects 2-3 args, got {}", args.len())),
    };

    emit_expr(ctx, cond)?;
    ctx.emit(f64_const(0.0));
    ctx.emit(Instruction::F64Ne);
    ctx.emit(Instruction::If(wasm_encoder::BlockType::Result(ValType::F64)));

    ctx.block_depth += 1;
    emit_expr(ctx, then_branch)?;
    ctx.emit(Instruction::Else);
    match else_branch {
        Some(expr) => emit_expr(ctx, expr)?,
        None => ctx.emit(f64_const(0.0)),
    }
    ctx.block_depth -= 1;

    ctx.emit(Instruction::End);
    Ok(())
}
/// Lowers `&let`: the first item of `body` is the binding, the rest is the
/// body evaluated with that binding in scope.
///
/// A `nil` or empty binding just evaluates the body. An empty `body` yields
/// `0.0`. The bound value is evaluated before the name is declared, so it
/// still sees any outer binding of the same name.
///
/// The binding is scoped to the body: once the body has been emitted, the
/// name is bound to whatever it referred to before (or unbound), so a
/// shadowed parameter or outer local is visible again afterwards.
///
/// # Errors
///
/// Returns an error when the binding is not a `(name value)` pair with a
/// symbol-like name, or when the value or body cannot be lowered.
fn emit_let(ctx: &mut WasmGenCtx, body: &[Calcit]) -> Result<(), String> {
    let Some((pair, rest)) = body.split_first() else {
        ctx.emit(f64_const(0.0));
        return Ok(());
    };
    match pair {
        Calcit::Nil => emit_body(ctx, rest),
        Calcit::List(xs) if xs.is_empty() => emit_body(ctx, rest),
        Calcit::List(xs) if xs.len() == 2 => {
            let var_name = match &xs[0] {
                Calcit::Local(CalcitLocal { sym, .. }) => sym.to_string(),
                Calcit::Symbol { sym, .. } => sym.to_string(),
                other => return Err(format!("let binding expected symbol, got: {other}")),
            };
            emit_expr(ctx, &xs[1])?;

            // Remember what the name referred to so it can be restored.
            let shadowed = ctx.locals.get(&var_name).copied();
            let idx = ctx.declare_local(&var_name);
            ctx.emit(Instruction::LocalSet(idx));

            // A body consisting of a nested `&let` goes through `emit_expr`
            // and lands here again, so it needs no special case.
            let result = emit_body(ctx, rest);

            match shadowed {
                Some(previous) => {
                    ctx.locals.insert(var_name, previous);
                }
                None => {
                    ctx.locals.remove(&var_name);
                }
            }
            result
        }
        _ => Err(format!("unsupported let binding form: {pair}")),
    }
}