use cel::common::ast::{CallExpr, IdedExpr};
use ferricel_types::functions::RuntimeFunction;
use walrus::{InstrSeqBuilder, ValType};
use crate::compiler::{
context::{CompilerContext, CompilerEnv},
expr::compile_expr,
operators::{compile_binary_op, compile_unary_op},
};
pub fn compile_ext_math_function(
func_name: &str,
call_expr: &CallExpr,
body: &mut InstrSeqBuilder,
env: &CompilerEnv,
ctx: &CompilerContext,
module: &mut walrus::Module,
) -> Result<(), anyhow::Error> {
match func_name {
"greatest" => compile_math_minmax(
call_expr,
body,
env,
ctx,
module,
RuntimeFunction::MathGreatest,
)?,
"least" => compile_math_minmax(
call_expr,
body,
env,
ctx,
module,
RuntimeFunction::MathLeast,
)?,
"ceil" => compile_unary_op(
call_expr,
"math.ceil()",
RuntimeFunction::MathCeil,
body,
env,
ctx,
module,
)?,
"floor" => compile_unary_op(
call_expr,
"math.floor()",
RuntimeFunction::MathFloor,
body,
env,
ctx,
module,
)?,
"round" => compile_unary_op(
call_expr,
"math.round()",
RuntimeFunction::MathRound,
body,
env,
ctx,
module,
)?,
"trunc" => compile_unary_op(
call_expr,
"math.trunc()",
RuntimeFunction::MathTrunc,
body,
env,
ctx,
module,
)?,
"abs" => compile_unary_op(
call_expr,
"math.abs()",
RuntimeFunction::MathAbs,
body,
env,
ctx,
module,
)?,
"sign" => compile_unary_op(
call_expr,
"math.sign()",
RuntimeFunction::MathSign,
body,
env,
ctx,
module,
)?,
"isInf" => compile_unary_op(
call_expr,
"math.isInf()",
RuntimeFunction::MathIsInf,
body,
env,
ctx,
module,
)?,
"isNaN" => compile_unary_op(
call_expr,
"math.isNaN()",
RuntimeFunction::MathIsNaN,
body,
env,
ctx,
module,
)?,
"isFinite" => compile_unary_op(
call_expr,
"math.isFinite()",
RuntimeFunction::MathIsFinite,
body,
env,
ctx,
module,
)?,
"bitOr" => compile_binary_op(
call_expr,
"math.bitOr()",
RuntimeFunction::MathBitOr,
body,
env,
ctx,
module,
)?,
"bitAnd" => compile_binary_op(
call_expr,
"math.bitAnd()",
RuntimeFunction::MathBitAnd,
body,
env,
ctx,
module,
)?,
"bitXor" => compile_binary_op(
call_expr,
"math.bitXor()",
RuntimeFunction::MathBitXor,
body,
env,
ctx,
module,
)?,
"bitNot" => compile_unary_op(
call_expr,
"math.bitNot()",
RuntimeFunction::MathBitNot,
body,
env,
ctx,
module,
)?,
"bitShiftLeft" => compile_binary_op(
call_expr,
"math.bitShiftLeft()",
RuntimeFunction::MathBitShiftLeft,
body,
env,
ctx,
module,
)?,
"bitShiftRight" => compile_binary_op(
call_expr,
"math.bitShiftRight()",
RuntimeFunction::MathBitShiftRight,
body,
env,
ctx,
module,
)?,
"sqrt" => compile_unary_op(
call_expr,
"math.sqrt()",
RuntimeFunction::MathSqrt,
body,
env,
ctx,
module,
)?,
_ => anyhow::bail!("Unknown math extension function: math.{}", func_name),
}
Ok(())
}
/// Lowers a variadic `math.greatest(...)` / `math.least(...)` call to a call of `runtime_fn`.
///
/// A lone argument is compiled and passed through directly. Two or more arguments are
/// first packed into a runtime array by `compile_args_as_array`. Either way, exactly one
/// value is on the stack when `runtime_fn` is called.
/// NOTE(review): this presumes the runtime accepts both a single value (possibly a
/// list) and an array of values — confirm against the runtime implementation.
///
/// # Errors
/// Fails when the call has no arguments, or when compiling any argument fails.
fn compile_math_minmax(
    call_expr: &CallExpr,
    body: &mut InstrSeqBuilder,
    env: &CompilerEnv,
    ctx: &CompilerContext,
    module: &mut walrus::Module,
    runtime_fn: RuntimeFunction,
) -> Result<(), anyhow::Error> {
    match call_expr.args.as_slice() {
        // NOTE(review): the message renders `runtime_fn` through its `Display` impl;
        // confirm that produces the CEL-facing name rather than an internal identifier.
        [] => anyhow::bail!("math.{}() requires at least one argument", runtime_fn),
        [only] => {
            compile_expr(&only.expr, body, env, ctx, module)?;
        }
        several => {
            compile_args_as_array(several, body, env, ctx, module)?;
        }
    }

    body.call(env.get(runtime_fn));
    Ok(())
}
/// Emits code that builds a runtime array from `args` and leaves its handle on the stack.
///
/// The array is created with `RuntimeFunction::CreateArray` and kept in an `i32` local.
/// Each argument is then compiled in order and appended with `RuntimeFunction::ArrayPush`,
/// with the array handle and the element value pushed in that order. When the function
/// returns, the array handle is the value left on top of the operand stack.
/// NOTE(review): assumes `ArrayPush` takes `(array, element)` and returns nothing, since
/// no value is dropped after the call — confirm against the runtime signature.
///
/// # Errors
/// Propagates the first error raised while compiling an argument.
fn compile_args_as_array(
    args: &[IdedExpr],
    body: &mut InstrSeqBuilder,
    env: &CompilerEnv,
    ctx: &CompilerContext,
    module: &mut walrus::Module,
) -> Result<(), anyhow::Error> {
    body.call(env.get(RuntimeFunction::CreateArray));
    let array_local = module.locals.add(ValType::I32);
    body.local_set(array_local);

    // A single scratch local is sufficient. Each element is stored and immediately
    // reloaded before the next argument is compiled, so the slot is never live across
    // iterations. Allocating a new local per argument only inflated the function's
    // local declarations.
    let elem_local = module.locals.add(ValType::I32);
    for arg in args {
        compile_expr(&arg.expr, body, env, ctx, module)?;
        body.local_set(elem_local);
        body.local_get(array_local);
        body.local_get(elem_local);
        body.call(env.get(RuntimeFunction::ArrayPush));
    }

    body.local_get(array_local);
    Ok(())
}