use super::{JitCompileError, JitSig, JitType};
use alloc::collections::BTreeSet;
use cranelift::codegen::ir::FuncRef;
use cranelift::prelude::*;
use num_traits::cast::ToPrimitive;
use rustpython_compiler_core::bytecode::{
self, BinaryOperator, BorrowedConstant, CodeObject, ComparisonOperator, Instruction,
IntrinsicFunction1, Label, OpArg, OpArgState, oparg,
};
use std::collections::HashMap;
/// Trap codes emitted by JIT-compiled code on top of Cranelift's built-in ones.
///
/// The discriminant is what gets handed to `TrapCode::user` when the trap is
/// emitted.
#[repr(u16)]
enum CustomTrapCode {
    /// A shift instruction received a negative shift count.
    NegativeShiftCount = 0x0001,
}
/// A typed local-variable slot of the function being compiled.
#[derive(Clone)]
struct Local {
    /// Cranelift variable that holds the local's current SSA value.
    var: Variable,
    /// JIT type fixed by the first store; later stores must use the same type.
    ty: JitType,
}
/// A value on the compiler's abstract operand stack.
#[derive(Debug)]
enum JitValue {
    /// An `i64` IR value.
    Int(Value),
    /// An `f64` IR value.
    Float(Value),
    /// An `i8` IR value holding 0 or 1.
    Bool(Value),
    /// Python `None`; carries no IR value.
    None,
    /// The NULL marker pushed by `PushNull` / `LoadGlobal`.
    Null,
    /// A tuple built at compile time; its elements stay separate.
    Tuple(Vec<JitValue>),
    /// A reference to a callable Cranelift function.
    FuncRef(FuncRef),
}
impl JitValue {
    /// Wraps an IR value in the variant matching its JIT type.
    fn from_type_and_value(ty: JitType, val: Value) -> JitValue {
        match ty {
            JitType::Int => Self::Int(val),
            JitType::Float => Self::Float(val),
            JitType::Bool => Self::Bool(val),
        }
    }

    /// Returns the JIT type of scalar values, or `None` for values that have no
    /// single IR representation (`None`, `Null`, tuples, function references).
    fn to_jit_type(&self) -> Option<JitType> {
        let ty = match self {
            Self::Int(_) => JitType::Int,
            Self::Float(_) => JitType::Float,
            Self::Bool(_) => JitType::Bool,
            Self::None | Self::Null | Self::Tuple(_) | Self::FuncRef(_) => return None,
        };
        Some(ty)
    }

    /// Extracts the underlying IR value of a scalar; `None` for non-scalars.
    fn into_value(self) -> Option<Value> {
        match self {
            Self::Int(v) | Self::Float(v) | Self::Bool(v) => Some(v),
            Self::None | Self::Null | Self::Tuple(_) | Self::FuncRef(_) => None,
        }
    }
}
/// A double-double number: the unevaluated sum `hi + lo` of two `f64` IR
/// values, manipulated by the `dd_*` helpers to carry extra precision.
#[derive(Clone)]
struct DDValue {
    /// Leading component.
    hi: Value,
    /// Trailing correction term.
    lo: Value,
}
/// Translates one Python code object into the body of a Cranelift function.
///
/// Tracks the abstract operand stack, the typed local-variable slots, the
/// blocks created for jump targets, and the signature known so far.
pub struct FunctionCompiler<'a, 'b> {
    /// Builder positioned inside the function being generated.
    builder: &'a mut FunctionBuilder<'b>,
    /// Compile-time model of the interpreter's value stack.
    stack: Vec<JitValue>,
    /// One slot per local variable; `None` until the first store.
    variables: Box<[Option<Local>]>,
    /// Cranelift block created for each jump-target label.
    label_to_block: HashMap<Label, Block>,
    /// Argument types and the (given or inferred) return type.
    pub(crate) sig: JitSig,
}
impl<'a, 'b> FunctionCompiler<'a, 'b> {
/// Creates a compiler for a function with `num_variables` local slots.
///
/// The entry block's parameters are stored into locals `0..arg_types.len()`,
/// typed by `arg_types`.
///
/// # Panics
/// Panics if storing an argument fails (e.g. `num_variables` is smaller than
/// the number of arguments).
pub fn new(
    builder: &'a mut FunctionBuilder<'b>,
    num_variables: usize,
    arg_types: &[JitType],
    ret_type: Option<JitType>,
    entry_block: Block,
) -> FunctionCompiler<'a, 'b> {
    let sig = JitSig {
        args: arg_types.to_vec(),
        ret: ret_type,
    };
    let mut compiler = FunctionCompiler {
        builder,
        stack: Vec::new(),
        variables: vec![None; num_variables].into_boxed_slice(),
        label_to_block: HashMap::new(),
        sig,
    };

    let entry_params = compiler.builder.func.dfg.block_params(entry_block).to_vec();
    for (index, (ty, param)) in arg_types.iter().zip(entry_params).enumerate() {
        let value = JitValue::from_type_and_value(ty.clone(), param);
        compiler.store_variable((index as u32).into(), value).unwrap();
    }
    compiler
}
/// Removes the top `count` stack entries and returns them in push order
/// (deepest first).
///
/// # Panics
/// Panics if the stack holds fewer than `count` entries.
fn pop_multiple(&mut self, count: usize) -> Vec<JitValue> {
    let split_at = self.stack.len() - count;
    self.stack.split_off(split_at)
}
/// Stores a scalar value into local slot `idx`.
///
/// The first store declares the Cranelift variable and fixes the slot's type.
///
/// # Errors
/// `NotSupported` if `val` is not a scalar, or if its type differs from the
/// type the slot was first stored with.
fn store_variable(&mut self, idx: oparg::VarNum, val: JitValue) -> Result<(), JitCompileError> {
    let Some(ty) = val.to_jit_type() else {
        return Err(JitCompileError::NotSupported);
    };
    let builder = &mut self.builder;
    let slot = &mut self.variables[idx];
    let local = slot.get_or_insert_with(|| Local {
        var: builder.declare_var(ty.to_cranelift()),
        ty: ty.clone(),
    });
    if local.ty != ty {
        return Err(JitCompileError::NotSupported);
    }
    self.builder.def_var(local.var, val.into_value().unwrap());
    Ok(())
}
/// Produces an `i8` truth value for `val`, following Python truthiness:
/// non-zero ints and floats (NaN included) are true, `None` is false, and
/// bools pass through unchanged.
///
/// # Errors
/// `NotSupported` for `Null`, tuples, and function references.
fn boolean_val(&mut self, val: JitValue) -> Result<Value, JitCompileError> {
    let truthy = match val {
        JitValue::Bool(flag) => flag,
        JitValue::Int(int) => {
            let zero = self.builder.ins().iconst(types::I64, 0);
            self.builder.ins().icmp(IntCC::NotEqual, int, zero)
        }
        JitValue::Float(float) => {
            let zero = self.builder.ins().f64const(0.0);
            self.builder.ins().fcmp(FloatCC::NotEqual, float, zero)
        }
        JitValue::None => self.builder.ins().iconst(types::I8, 0),
        JitValue::Null | JitValue::Tuple(_) | JitValue::FuncRef(_) => {
            return Err(JitCompileError::NotSupported);
        }
    };
    Ok(truthy)
}
/// Returns the block associated with `label`, creating and recording a new
/// one on first use.
fn get_or_create_block(&mut self, label: Label) -> Block {
    if let Some(&existing) = self.label_to_block.get(&label) {
        return existing;
    }
    let fresh = self.builder.create_block();
    self.label_to_block.insert(label, fresh);
    fresh
}
/// Resolves a forward relative jump.
///
/// The distance `arg` is counted from the first code unit after the
/// instruction and its `caches` inline cache entries.
///
/// # Errors
/// `BadBytecode` if any intermediate offset overflows `u32`.
fn jump_target_forward(offset: u32, caches: u32, arg: OpArg) -> Result<Label, JitCompileError> {
    let next_unit = offset
        .checked_add(1)
        .and_then(|unit| unit.checked_add(caches));
    next_unit
        .and_then(|base| base.checked_add(u32::from(arg)))
        .map(Label::from_u32)
        .ok_or(JitCompileError::BadBytecode)
}
/// Resolves a backward relative jump.
///
/// The distance `arg` is subtracted from the first code unit after the
/// instruction and its `caches` inline cache entries.
///
/// # Errors
/// `BadBytecode` if the offset overflows or the target would be negative.
fn jump_target_backward(
    offset: u32,
    caches: u32,
    arg: OpArg,
) -> Result<Label, JitCompileError> {
    let next_unit = offset
        .checked_add(1)
        .and_then(|unit| unit.checked_add(caches));
    next_unit
        .and_then(|base| base.checked_sub(u32::from(arg)))
        .map(Label::from_u32)
        .ok_or(JitCompileError::BadBytecode)
}
/// Returns the label a jumping instruction can transfer control to, or
/// `None` for instructions that never jump.
///
/// # Errors
/// `BadBytecode` if the target offset cannot be computed.
fn instruction_target(
    offset: u32,
    instruction: Instruction,
    arg: OpArg,
) -> Result<Option<Label>, JitCompileError> {
    let caches = instruction.cache_entries() as u32;
    let label = match instruction {
        Instruction::JumpBackward { .. } | Instruction::JumpBackwardNoInterrupt { .. } => {
            Self::jump_target_backward(offset, caches, arg)?
        }
        Instruction::JumpForward { .. }
        | Instruction::PopJumpIfFalse { .. }
        | Instruction::PopJumpIfTrue { .. }
        | Instruction::PopJumpIfNone { .. }
        | Instruction::PopJumpIfNotNone { .. }
        | Instruction::ForIter { .. }
        | Instruction::Send { .. } => Self::jump_target_forward(offset, caches, arg)?,
        _ => return Ok(None),
    };
    Ok(Some(label))
}
pub fn compile<C: bytecode::Constant>(
&mut self,
func_ref: FuncRef,
bytecode: &CodeObject<C>,
) -> Result<(), JitCompileError> {
let clean_instructions: bytecode::CodeUnits = bytecode
.instructions
.original_bytes()
.as_slice()
.try_into()
.map_err(|_| JitCompileError::BadBytecode)?;
let mut label_targets = BTreeSet::new();
let mut target_arg_state = OpArgState::default();
for (offset, &raw_instr) in clean_instructions.iter().enumerate() {
let (instruction, arg) = target_arg_state.get(raw_instr);
if let Some(target) = Self::instruction_target(offset as u32, instruction, arg)? {
label_targets.insert(target);
}
}
let mut arg_state = OpArgState::default();
let mut in_unreachable_code = false;
for (offset, &raw_instr) in clean_instructions.iter().enumerate() {
let label = Label::from_u32(offset as u32);
let (instruction, arg) = arg_state.get(raw_instr);
if label_targets.contains(&label) {
let target_block = self.get_or_create_block(label);
if let Some(cur) = self.builder.current_block()
&& cur != target_block
{
let needs_terminator = match self.builder.func.layout.last_inst(cur) {
None => true, Some(inst) => {
!self.builder.func.dfg.insts[inst].opcode().is_terminator()
}
};
if needs_terminator {
self.builder.ins().jump(target_block, &[]);
}
}
if self.builder.current_block() != Some(target_block) {
self.builder.switch_to_block(target_block);
}
in_unreachable_code = false;
}
if in_unreachable_code {
continue;
}
self.add_instruction(func_ref, bytecode, offset as u32, instruction, arg)?;
match instruction {
Instruction::ReturnValue
| Instruction::JumpBackward { .. }
| Instruction::JumpBackwardNoInterrupt { .. }
| Instruction::JumpForward { .. } => {
in_unreachable_code = true;
}
_ => {}
}
}
if let Some(cur) = self.builder.current_block() {
let needs_terminator = match self.builder.func.layout.last_inst(cur) {
None => true,
Some(inst) => !self.builder.func.dfg.insts[inst].opcode().is_terminator(),
};
if needs_terminator {
self.builder.ins().trap(TrapCode::user(0).unwrap());
}
}
Ok(())
}
/// Materializes a bytecode constant as a JIT value.
///
/// Integers become `i64` constants, floats `f64` constants, booleans `i8`
/// 0/1 constants, and `None` stays symbolic.
///
/// # Errors
/// `NotSupported` for integers that do not fit in `i64` and for every other
/// constant kind.
fn prepare_const<C: bytecode::Constant>(
    &mut self,
    constant: BorrowedConstant<'_, C>,
) -> Result<JitValue, JitCompileError> {
    match constant {
        BorrowedConstant::Integer { value } => {
            let as_i64 = value.to_i64().ok_or(JitCompileError::NotSupported)?;
            Ok(JitValue::Int(self.builder.ins().iconst(types::I64, as_i64)))
        }
        BorrowedConstant::Float { value } => {
            Ok(JitValue::Float(self.builder.ins().f64const(value)))
        }
        BorrowedConstant::Boolean { value } => Ok(JitValue::Bool(
            self.builder.ins().iconst(types::I8, i64::from(value)),
        )),
        BorrowedConstant::None => Ok(JitValue::None),
        _ => Err(JitCompileError::NotSupported),
    }
}
/// Emits a `return` of `val`.
///
/// If no return type is known yet, the type of `val` becomes the function's
/// return type and is appended to the Cranelift signature.
///
/// # Errors
/// `NotSupported` if `val` is not a scalar or disagrees with the known
/// return type.
fn return_value(&mut self, val: JitValue) -> Result<(), JitCompileError> {
    match &self.sig.ret {
        Some(expected) => {
            if val.to_jit_type().as_ref() != Some(expected) {
                return Err(JitCompileError::NotSupported);
            }
        }
        None => {
            let inferred = val.to_jit_type().ok_or(JitCompileError::NotSupported)?;
            self.sig.ret = Some(inferred.clone());
            let param = AbiParam::new(inferred.to_cranelift());
            self.builder.func.signature.returns.push(param);
        }
    }
    let returned = val.into_value().ok_or(JitCompileError::NotSupported)?;
    self.builder.ins().return_(&[returned]);
    Ok(())
}
pub fn add_instruction<C: bytecode::Constant>(
&mut self,
func_ref: FuncRef,
bytecode: &CodeObject<C>,
offset: u32,
instruction: Instruction,
arg: OpArg,
) -> Result<(), JitCompileError> {
match instruction {
Instruction::BinaryOp { op } => {
let op = op.get(arg);
let b = self.stack.pop().ok_or(JitCompileError::BadBytecode)?;
let a = self.stack.pop().ok_or(JitCompileError::BadBytecode)?;
let a_type = a.to_jit_type();
let b_type = b.to_jit_type();
let val = match (op, a, b) {
(
BinaryOperator::Add | BinaryOperator::InplaceAdd,
JitValue::Int(a),
JitValue::Int(b),
) => {
let (out, carry) = self.builder.ins().sadd_overflow(a, b);
self.builder.ins().trapnz(carry, TrapCode::INTEGER_OVERFLOW);
JitValue::Int(out)
}
(
BinaryOperator::Subtract | BinaryOperator::InplaceSubtract,
JitValue::Int(a),
JitValue::Int(b),
) => JitValue::Int(self.compile_sub(a, b)),
(
BinaryOperator::FloorDivide | BinaryOperator::InplaceFloorDivide,
JitValue::Int(a),
JitValue::Int(b),
) => JitValue::Int(self.builder.ins().sdiv(a, b)),
(
BinaryOperator::TrueDivide | BinaryOperator::InplaceTrueDivide,
JitValue::Int(a),
JitValue::Int(b),
) => {
self.builder
.ins()
.trapz(b, TrapCode::INTEGER_DIVISION_BY_ZERO);
let a_float = self.builder.ins().fcvt_from_sint(types::F64, a);
let b_float = self.builder.ins().fcvt_from_sint(types::F64, b);
JitValue::Float(self.builder.ins().fdiv(a_float, b_float))
}
(
BinaryOperator::Multiply | BinaryOperator::InplaceMultiply,
JitValue::Int(a),
JitValue::Int(b),
) => JitValue::Int(self.builder.ins().imul(a, b)),
(
BinaryOperator::Remainder | BinaryOperator::InplaceRemainder,
JitValue::Int(a),
JitValue::Int(b),
) => JitValue::Int(self.builder.ins().srem(a, b)),
(
BinaryOperator::Power | BinaryOperator::InplacePower,
JitValue::Int(a),
JitValue::Int(b),
) => JitValue::Int(self.compile_ipow(a, b)),
(
BinaryOperator::Lshift | BinaryOperator::Rshift,
JitValue::Int(a),
JitValue::Int(b),
) => {
let sign = self.builder.ins().ushr_imm(b, 63);
self.builder.ins().trapnz(
sign,
TrapCode::user(CustomTrapCode::NegativeShiftCount as u8).unwrap(),
);
let out =
if matches!(op, BinaryOperator::Lshift | BinaryOperator::InplaceLshift)
{
self.builder.ins().ishl(a, b)
} else {
self.builder.ins().sshr(a, b)
};
JitValue::Int(out)
}
(
BinaryOperator::And | BinaryOperator::InplaceAnd,
JitValue::Int(a),
JitValue::Int(b),
) => JitValue::Int(self.builder.ins().band(a, b)),
(
BinaryOperator::Or | BinaryOperator::InplaceOr,
JitValue::Int(a),
JitValue::Int(b),
) => JitValue::Int(self.builder.ins().bor(a, b)),
(
BinaryOperator::Xor | BinaryOperator::InplaceXor,
JitValue::Int(a),
JitValue::Int(b),
) => JitValue::Int(self.builder.ins().bxor(a, b)),
(
BinaryOperator::Add | BinaryOperator::InplaceAdd,
JitValue::Float(a),
JitValue::Float(b),
) => JitValue::Float(self.builder.ins().fadd(a, b)),
(
BinaryOperator::Subtract | BinaryOperator::InplaceSubtract,
JitValue::Float(a),
JitValue::Float(b),
) => JitValue::Float(self.builder.ins().fsub(a, b)),
(
BinaryOperator::Multiply | BinaryOperator::InplaceMultiply,
JitValue::Float(a),
JitValue::Float(b),
) => JitValue::Float(self.builder.ins().fmul(a, b)),
(
BinaryOperator::TrueDivide | BinaryOperator::InplaceTrueDivide,
JitValue::Float(a),
JitValue::Float(b),
) => JitValue::Float(self.builder.ins().fdiv(a, b)),
(
BinaryOperator::Power | BinaryOperator::InplacePower,
JitValue::Float(a),
JitValue::Float(b),
) => JitValue::Float(self.compile_fpow(a, b)),
(_, JitValue::Int(a), JitValue::Float(b))
| (_, JitValue::Float(a), JitValue::Int(b)) => {
let operand_one = match a_type.unwrap() {
JitType::Int => self.builder.ins().fcvt_from_sint(types::F64, a),
_ => a,
};
let operand_two = match b_type.unwrap() {
JitType::Int => self.builder.ins().fcvt_from_sint(types::F64, b),
_ => b,
};
match op {
BinaryOperator::Add | BinaryOperator::InplaceAdd => {
JitValue::Float(self.builder.ins().fadd(operand_one, operand_two))
}
BinaryOperator::Subtract | BinaryOperator::InplaceSubtract => {
JitValue::Float(self.builder.ins().fsub(operand_one, operand_two))
}
BinaryOperator::Multiply | BinaryOperator::InplaceMultiply => {
JitValue::Float(self.builder.ins().fmul(operand_one, operand_two))
}
BinaryOperator::TrueDivide | BinaryOperator::InplaceTrueDivide => {
JitValue::Float(self.builder.ins().fdiv(operand_one, operand_two))
}
BinaryOperator::Power | BinaryOperator::InplacePower => {
JitValue::Float(self.compile_fpow(operand_one, operand_two))
}
_ => return Err(JitCompileError::NotSupported),
}
}
_ => return Err(JitCompileError::NotSupported),
};
self.stack.push(val);
Ok(())
}
Instruction::BuildTuple { count } => {
let elements = self.pop_multiple(count.get(arg) as usize);
self.stack.push(JitValue::Tuple(elements));
Ok(())
}
Instruction::Call { argc } => {
let nargs = argc.get(arg);
let mut args = Vec::new();
for _ in 0..nargs {
let arg = self.stack.pop().ok_or(JitCompileError::BadBytecode)?;
args.push(arg.into_value().unwrap());
}
let self_or_null = self.stack.pop().ok_or(JitCompileError::BadBytecode)?;
if !matches!(self_or_null, JitValue::Null) {
return Err(JitCompileError::NotSupported);
}
match self.stack.pop().ok_or(JitCompileError::BadBytecode)? {
JitValue::FuncRef(reference) => {
let call = self.builder.ins().call(reference, &args);
let returns = self.builder.inst_results(call);
self.stack.push(JitValue::Int(returns[0]));
Ok(())
}
_ => Err(JitCompileError::BadBytecode),
}
}
Instruction::PushNull => {
self.stack.push(JitValue::Null);
Ok(())
}
Instruction::CallIntrinsic1 { func } => {
match func.get(arg) {
IntrinsicFunction1::UnaryPositive => {
match self.stack.pop().ok_or(JitCompileError::BadBytecode)? {
JitValue::Int(val) => {
self.stack.push(JitValue::Int(val));
Ok(())
}
_ => Err(JitCompileError::NotSupported),
}
}
_ => Err(JitCompileError::NotSupported),
}
}
Instruction::CompareOp { opname } => {
let op = opname.get(arg);
let b = self.stack.pop().ok_or(JitCompileError::BadBytecode)?;
let a = self.stack.pop().ok_or(JitCompileError::BadBytecode)?;
let a_type: Option<JitType> = a.to_jit_type();
let b_type: Option<JitType> = b.to_jit_type();
match (a, b) {
(JitValue::Int(a), JitValue::Int(b))
| (JitValue::Bool(a), JitValue::Bool(b))
| (JitValue::Bool(a), JitValue::Int(b))
| (JitValue::Int(a), JitValue::Bool(b)) => {
let operand_one = match a_type.unwrap() {
JitType::Bool => self.builder.ins().uextend(types::I64, a),
_ => a,
};
let operand_two = match b_type.unwrap() {
JitType::Bool => self.builder.ins().uextend(types::I64, b),
_ => b,
};
let cond = match op {
ComparisonOperator::Equal => IntCC::Equal,
ComparisonOperator::NotEqual => IntCC::NotEqual,
ComparisonOperator::Less => IntCC::SignedLessThan,
ComparisonOperator::LessOrEqual => IntCC::SignedLessThanOrEqual,
ComparisonOperator::Greater => IntCC::SignedGreaterThan,
ComparisonOperator::GreaterOrEqual => IntCC::SignedGreaterThanOrEqual,
};
let val = self.builder.ins().icmp(cond, operand_one, operand_two);
self.stack.push(JitValue::Bool(val));
Ok(())
}
(JitValue::Float(a), JitValue::Float(b)) => {
let cond = match op {
ComparisonOperator::Equal => FloatCC::Equal,
ComparisonOperator::NotEqual => FloatCC::NotEqual,
ComparisonOperator::Less => FloatCC::LessThan,
ComparisonOperator::LessOrEqual => FloatCC::LessThanOrEqual,
ComparisonOperator::Greater => FloatCC::GreaterThan,
ComparisonOperator::GreaterOrEqual => FloatCC::GreaterThanOrEqual,
};
let val = self.builder.ins().fcmp(cond, a, b);
self.stack.push(JitValue::Bool(val));
Ok(())
}
_ => Err(JitCompileError::NotSupported),
}
}
Instruction::ExtendedArg
| Instruction::Cache
| Instruction::MakeCell { .. }
| Instruction::CopyFreeVars { .. } => Ok(()),
Instruction::JumpBackward { .. }
| Instruction::JumpBackwardNoInterrupt { .. }
| Instruction::JumpForward { .. } => {
let target = Self::instruction_target(offset, instruction, arg)?
.ok_or(JitCompileError::BadBytecode)?;
let target_block = self.get_or_create_block(target);
self.builder.ins().jump(target_block, &[]);
Ok(())
}
Instruction::LoadConst { consti } => {
let val =
self.prepare_const(bytecode.constants[consti.get(arg)].borrow_constant())?;
self.stack.push(val);
Ok(())
}
Instruction::LoadSmallInt { i } => {
let small_int = i.get(arg) as i64;
let val = self.builder.ins().iconst(types::I64, small_int);
self.stack.push(JitValue::Int(val));
Ok(())
}
Instruction::LoadFast { var_num } | Instruction::LoadFastBorrow { var_num } => {
let local = self.variables[var_num.get(arg)]
.as_ref()
.ok_or(JitCompileError::BadBytecode)?;
self.stack.push(JitValue::from_type_and_value(
local.ty.clone(),
self.builder.use_var(local.var),
));
Ok(())
}
Instruction::LoadFastLoadFast { var_nums }
| Instruction::LoadFastBorrowLoadFastBorrow { var_nums } => {
let oparg = var_nums.get(arg);
let (idx1, idx2) = oparg.indexes();
for idx in [idx1, idx2] {
let local = self.variables[idx]
.as_ref()
.ok_or(JitCompileError::BadBytecode)?;
self.stack.push(JitValue::from_type_and_value(
local.ty.clone(),
self.builder.use_var(local.var),
));
}
Ok(())
}
Instruction::LoadGlobal { namei } => {
let oparg = namei.get(arg);
let name = &bytecode.names[(oparg >> 1) as usize];
if name.as_ref() != bytecode.obj_name.as_ref() {
Err(JitCompileError::NotSupported)
} else {
self.stack.push(JitValue::FuncRef(func_ref));
if (oparg & 1) != 0 {
self.stack.push(JitValue::Null);
}
Ok(())
}
}
Instruction::Nop | Instruction::NotTaken => Ok(()),
Instruction::PopJumpIfFalse { .. } => {
let cond = self.stack.pop().ok_or(JitCompileError::BadBytecode)?;
let val = self.boolean_val(cond)?;
let then_label = Self::instruction_target(offset, instruction, arg)?
.ok_or(JitCompileError::BadBytecode)?;
let then_block = self.get_or_create_block(then_label);
let else_block = self.builder.create_block();
self.builder
.ins()
.brif(val, else_block, &[], then_block, &[]);
self.builder.switch_to_block(else_block);
Ok(())
}
Instruction::PopJumpIfTrue { .. } => {
let cond = self.stack.pop().ok_or(JitCompileError::BadBytecode)?;
let val = self.boolean_val(cond)?;
let then_label = Self::instruction_target(offset, instruction, arg)?
.ok_or(JitCompileError::BadBytecode)?;
let then_block = self.get_or_create_block(then_label);
let else_block = self.builder.create_block();
self.builder
.ins()
.brif(val, then_block, &[], else_block, &[]);
self.builder.switch_to_block(else_block);
Ok(())
}
Instruction::PopTop => {
self.stack.pop();
Ok(())
}
Instruction::Resume { .. } => {
Ok(())
}
Instruction::ReturnValue => {
let val = self.stack.pop().ok_or(JitCompileError::BadBytecode)?;
self.return_value(val)
}
Instruction::StoreFast { var_num } => {
let val = self.stack.pop().ok_or(JitCompileError::BadBytecode)?;
self.store_variable(var_num.get(arg), val)
}
Instruction::StoreFastLoadFast { var_nums } => {
let oparg = var_nums.get(arg);
let (store_idx, load_idx) = oparg.indexes();
let val = self.stack.pop().ok_or(JitCompileError::BadBytecode)?;
self.store_variable(store_idx, val)?;
let local = self.variables[load_idx]
.as_ref()
.ok_or(JitCompileError::BadBytecode)?;
self.stack.push(JitValue::from_type_and_value(
local.ty.clone(),
self.builder.use_var(local.var),
));
Ok(())
}
Instruction::StoreFastStoreFast { var_nums } => {
let oparg = var_nums.get(arg);
let (idx1, idx2) = oparg.indexes();
let val1 = self.stack.pop().ok_or(JitCompileError::BadBytecode)?;
self.store_variable(idx1, val1)?;
let val2 = self.stack.pop().ok_or(JitCompileError::BadBytecode)?;
self.store_variable(idx2, val2)
}
Instruction::Swap { i: index } => {
let len = self.stack.len();
let i = len - 1;
let j = len - 1 - index.get(arg) as usize;
self.stack.swap(i, j);
Ok(())
}
Instruction::ToBool => {
let a = self.stack.pop().ok_or(JitCompileError::BadBytecode)?;
let value = self.boolean_val(a)?;
self.stack.push(JitValue::Bool(value));
Ok(())
}
Instruction::UnaryNot => {
let boolean = match self.stack.pop().ok_or(JitCompileError::BadBytecode)? {
JitValue::Bool(val) => val,
_ => return Err(JitCompileError::BadBytecode),
};
let not_boolean = self.builder.ins().bxor_imm(boolean, 1);
self.stack.push(JitValue::Bool(not_boolean));
Ok(())
}
Instruction::UnaryNegative => {
match self.stack.pop().ok_or(JitCompileError::BadBytecode)? {
JitValue::Int(val) => {
let zero = self.builder.ins().iconst(types::I64, 0);
let out = self.compile_sub(zero, val);
self.stack.push(JitValue::Int(out));
Ok(())
}
_ => Err(JitCompileError::NotSupported),
}
}
Instruction::UnpackSequence { count } => {
let val = self.stack.pop().ok_or(JitCompileError::BadBytecode)?;
let elements = match val {
JitValue::Tuple(elements) => elements,
_ => return Err(JitCompileError::NotSupported),
};
if elements.len() != count.get(arg) as usize {
return Err(JitCompileError::NotSupported);
}
self.stack.extend(elements.into_iter().rev());
Ok(())
}
_ => Err(JitCompileError::NotSupported),
}
}
/// Emits `a - b` on `i64`, trapping with `INTEGER_OVERFLOW` on signed overflow.
fn compile_sub(&mut self, a: Value, b: Value) -> Value {
    let (difference, overflowed) = self.builder.ins().ssub_overflow(a, b);
    self.builder
        .ins()
        .trapnz(overflowed, TrapCode::INTEGER_OVERFLOW);
    difference
}
/// Builds the double-double constant `(x, 0.0)`.
fn dd_from_f64(&mut self, x: f64) -> DDValue {
    let hi = self.builder.ins().f64const(x);
    let lo = self.builder.ins().f64const(0.0);
    DDValue { hi, lo }
}
/// Lifts an `f64` IR value into a double-double with a zero low part.
fn dd_from_value(&mut self, x: Value) -> DDValue {
    let lo = self.builder.ins().f64const(0.0);
    DDValue { hi: x, lo }
}
/// Builds the double-double constant `(hi, lo)` from two `f64` literals.
fn dd_from_parts(&mut self, hi: f64, lo: f64) -> DDValue {
    let hi_val = self.builder.ins().f64const(hi);
    let lo_val = self.builder.ins().f64const(lo);
    DDValue {
        hi: hi_val,
        lo: lo_val,
    }
}
fn dd_to_f64(&mut self, dd: DDValue) -> Value {
self.builder.ins().fadd(dd.hi, dd.lo)
}
/// Negates a double-double by computing `0 - hi` and `0 - lo`.
fn dd_neg(&mut self, dd: DDValue) -> DDValue {
    let zero = self.builder.ins().f64const(0.0);
    let DDValue { hi, lo } = dd;
    let neg_hi = self.builder.ins().fsub(zero, hi);
    let neg_lo = self.builder.ins().fsub(zero, lo);
    DDValue {
        hi: neg_hi,
        lo: neg_lo,
    }
}
/// Adds two double-double values.
///
/// The high parts are combined with a two-sum (`sum + sum_err` equals
/// `a.hi + b.hi` exactly); the low parts and `sum_err` are folded in, and the
/// pair is renormalized with a fast two-sum.
fn dd_add(&mut self, a: DDValue, b: DDValue) -> DDValue {
    let DDValue { hi: a_hi, lo: a_lo } = a;
    let DDValue { hi: b_hi, lo: b_lo } = b;
    // Two-sum of the high parts.
    let sum = self.builder.ins().fadd(a_hi, b_hi);
    let b_virtual = self.builder.ins().fsub(sum, a_hi);
    let a_virtual = self.builder.ins().fsub(sum, b_virtual);
    let a_err = self.builder.ins().fsub(a_hi, a_virtual);
    let b_err = self.builder.ins().fsub(b_hi, b_virtual);
    let sum_err = self.builder.ins().fadd(a_err, b_err);
    // Accumulate the low parts together with the rounding error.
    let low = self.builder.ins().fadd(a_lo, b_lo);
    let tail = self.builder.ins().fadd(low, sum_err);
    // Fast two-sum renormalization.
    let hi = self.builder.ins().fadd(sum, tail);
    let hi_delta = self.builder.ins().fsub(hi, sum);
    let lo = self.builder.ins().fsub(tail, hi_delta);
    DDValue { hi, lo }
}
/// Computes `a - b` as `a + (-b)` in double-double arithmetic.
fn dd_sub(&mut self, a: DDValue, b: DDValue) -> DDValue {
    let negated = self.dd_neg(b);
    self.dd_add(a, negated)
}
/// Multiplies two double-double values.
///
/// The exact error of `a.hi * b.hi` is recovered with an FMA, the cross terms
/// are added with a two-sum, the `a.lo * b.lo` term is folded in, and the
/// result is renormalized with a fast two-sum.
fn dd_mul(&mut self, a: DDValue, b: DDValue) -> DDValue {
    let DDValue { hi: a_hi, lo: a_lo } = a;
    let DDValue { hi: b_hi, lo: b_lo } = b;
    // Leading product and its exact rounding error.
    let product = self.builder.ins().fmul(a_hi, b_hi);
    let zero = self.builder.ins().f64const(0.0);
    let neg_product = self.builder.ins().fsub(zero, product);
    let product_err = self.builder.ins().fma(a_hi, b_hi, neg_product);
    // Cross terms.
    let hi_lo = self.builder.ins().fmul(a_hi, b_lo);
    let lo_hi = self.builder.ins().fmul(a_lo, b_hi);
    let cross = self.builder.ins().fadd(hi_lo, lo_hi);
    // Two-sum of product and cross terms.
    let sum = self.builder.ins().fadd(product, cross);
    let cross_virtual = self.builder.ins().fsub(sum, product);
    let product_virtual = self.builder.ins().fsub(sum, cross_virtual);
    let product_part = self.builder.ins().fsub(product, product_virtual);
    let cross_part = self.builder.ins().fsub(cross, cross_virtual);
    let sum_err = self.builder.ins().fadd(product_part, cross_part);
    // Remaining small contributions.
    let lo_lo = self.builder.ins().fmul(a_lo, b_lo);
    let errors = self.builder.ins().fadd(product_err, sum_err);
    let tail = self.builder.ins().fadd(errors, lo_lo);
    // Fast two-sum renormalization.
    let hi = self.builder.ins().fadd(sum, tail);
    let hi_delta = self.builder.ins().fsub(hi, sum);
    let lo = self.builder.ins().fsub(tail, hi_delta);
    DDValue { hi, lo }
}
/// Multiplies a double-double by a plain `f64` IR value.
///
/// Same scheme as `dd_mul`, but `b` has no low part, so there is only one
/// cross term and no `lo * lo` contribution.
fn dd_mul_f64(&mut self, a: DDValue, b: Value) -> DDValue {
    let DDValue { hi: a_hi, lo: a_lo } = a;
    // Leading product and its exact rounding error.
    let product = self.builder.ins().fmul(a_hi, b);
    let zero = self.builder.ins().f64const(0.0);
    let neg_product = self.builder.ins().fsub(zero, product);
    let product_err = self.builder.ins().fma(a_hi, b, neg_product);
    let cross = self.builder.ins().fmul(a_lo, b);
    // Two-sum of product and cross term.
    let sum = self.builder.ins().fadd(product, cross);
    let cross_virtual = self.builder.ins().fsub(sum, product);
    let product_virtual = self.builder.ins().fsub(sum, cross_virtual);
    let product_part = self.builder.ins().fsub(product, product_virtual);
    let cross_part = self.builder.ins().fsub(cross, cross_virtual);
    let sum_err = self.builder.ins().fadd(product_part, cross_part);
    let tail = self.builder.ins().fadd(product_err, sum_err);
    // Fast two-sum renormalization.
    let hi = self.builder.ins().fadd(sum, tail);
    let hi_delta = self.builder.ins().fsub(hi, sum);
    let lo = self.builder.ins().fsub(tail, hi_delta);
    DDValue { hi, lo }
}
/// Multiplies both parts of a double-double by `factor` (exact when `factor`
/// is a power of two and nothing overflows or underflows).
fn dd_scale(&mut self, dd: DDValue, factor: Value) -> DDValue {
    let DDValue { hi, lo } = dd;
    let scaled_hi = self.builder.ins().fmul(hi, factor);
    let scaled_lo = self.builder.ins().fmul(lo, factor);
    DDValue {
        hi: scaled_hi,
        lo: scaled_lo,
    }
}
/// Evaluates `ln(1 + f)` in double-double precision with the Mercator series
/// `f - f²/2 + f³/3 - …`, unrolled to 1000 terms at compile time.
///
/// The series only converges for `-1 < f <= 1`, and the truncation error after
/// `n` terms is about `|f|^(n+1) / (n+1)`, so callers should pass a
/// range-reduced `f`.
fn dd_ln_1p_series(&mut self, f: Value) -> DDValue {
    const N_TERMS: u32 = 1000;
    let f_dd = self.dd_from_value(f);
    let mut sum = f_dd.clone();
    // power == f^i in double-double.
    let mut power = f_dd;
    for i in 2..=N_TERMS {
        power = self.dd_mul_f64(power, f);
        let inv_i = self.builder.ins().f64const(1.0 / f64::from(i));
        let scaled = self.dd_mul_f64(power.clone(), inv_i);
        // Even-index terms are subtracted. Negating the (already normalized)
        // term is exact, so no double-double product with a ±1 constant is needed.
        let to_add = if i % 2 == 0 {
            self.dd_neg(scaled)
        } else {
            scaled
        };
        sum = self.dd_add(sum, to_add);
    }
    sum
}
/// Computes the natural logarithm of the `f64` IR value `x` in double-double
/// precision.
///
/// `x` is decomposed as `m · 2^k` with `m` in `[√½, √2)`, and the result is
/// `ln(1 + (m - 1)) + k·ln2`. Keeping `|m - 1| <= √2 - 1 ≈ 0.414` lets
/// `dd_ln_1p_series` converge fully; without the reduction, `m` near 2 left an
/// error of order 1e-4. Subnormal inputs are scaled by `2^54` first so their
/// exponent field is meaningful.
///
/// Returns NaN for `x <= 0` (including `-0.0`) and for NaN, and `+inf` for `+inf`.
fn dd_ln(&mut self, x: Value) -> DDValue {
    let dd_nan = self.dd_from_f64(f64::NAN);
    let zero_f64 = self.builder.ins().f64const(0.0);
    let cmp_le = self
        .builder
        .ins()
        .fcmp(FloatCC::LessThanOrEqual, x, zero_f64);
    let cmp_nan = self.builder.ins().fcmp(FloatCC::Unordered, x, x);
    let need_nan = self.builder.ins().bor(cmp_le, cmp_nan);
    let inf_f64 = self.builder.ins().f64const(f64::INFINITY);
    let is_inf = self.builder.ins().fcmp(FloatCC::Equal, x, inf_f64);

    // A zero exponent field marks a subnormal (or zero): scale it into the
    // normal range; the extra 54 is removed from `k` below.
    let exponent_mask = self.builder.ins().iconst(types::I64, 0x7FF);
    let raw_bits = self.builder.ins().bitcast(types::I64, MemFlags::new(), x);
    let raw_shifted = self.builder.ins().ushr_imm(raw_bits, 52);
    let raw_exponent = self.builder.ins().band(raw_shifted, exponent_mask);
    let is_subnormal = self.builder.ins().icmp_imm(IntCC::Equal, raw_exponent, 0);
    let two_pow_54 = self.builder.ins().f64const((1u64 << 54) as f64);
    let x_scaled = self.builder.ins().fmul(x, two_pow_54);
    let x_norm = self.builder.ins().select(is_subnormal, x_scaled, x);

    // x_norm = m · 2^k_raw with m in [1, 2).
    let bits = self
        .builder
        .ins()
        .bitcast(types::I64, MemFlags::new(), x_norm);
    let shift_52 = self.builder.ins().ushr_imm(bits, 52);
    let exponent = self.builder.ins().band(shift_52, exponent_mask);
    let bias = self.builder.ins().iconst(types::I64, 1023);
    let unbiased = self.builder.ins().isub(exponent, bias);
    let subnormal_shift = self.builder.ins().iconst(types::I64, 54);
    let unbiased_subnormal = self.builder.ins().isub(unbiased, subnormal_shift);
    let k_raw = self
        .builder
        .ins()
        .select(is_subnormal, unbiased_subnormal, unbiased);
    let fraction_mask = self.builder.ins().iconst(types::I64, 0x000F_FFFF_FFFF_FFFF);
    let fraction_part = self.builder.ins().band(bits, fraction_mask);
    let one_exponent = self.builder.ins().iconst(types::I64, 0x3FF0_0000_0000_0000);
    let m_bits = self.builder.ins().bor(fraction_part, one_exponent);
    let m = self
        .builder
        .ins()
        .bitcast(types::F64, MemFlags::new(), m_bits);

    // Fold m into [√½, √2): halve it (exactly) and bump k when m >= √2.
    let sqrt2 = self.builder.ins().f64const(std::f64::consts::SQRT_2);
    let m_is_large = self
        .builder
        .ins()
        .fcmp(FloatCC::GreaterThanOrEqual, m, sqrt2);
    let half = self.builder.ins().f64const(0.5);
    let m_halved = self.builder.ins().fmul(m, half);
    let m_reduced = self.builder.ins().select(m_is_large, m_halved, m);
    let k_bump = self.builder.ins().uextend(types::I64, m_is_large);
    let k_i64 = self.builder.ins().iadd(k_raw, k_bump);
    // Exact: m_reduced is within a factor of two of 1 (Sterbenz lemma).
    let one_f64 = self.builder.ins().f64const(1.0);
    let f_val = self.builder.ins().fsub(m_reduced, one_f64);
    let dd_ln_m = self.dd_ln_1p_series(f_val);

    let ln2_dd = self.dd_from_parts(
        f64::from_bits(0x3fe62e42fefa39ef),
        f64::from_bits(0x3c7abc9e3b39803f),
    );
    let k_f64 = self.builder.ins().fcvt_from_sint(types::F64, k_i64);
    let dd_ln2_k = self.dd_mul_f64(ln2_dd, k_f64);
    let normal_result = self.dd_add(dd_ln_m, dd_ln2_k);

    let hi_or_nan = self
        .builder
        .ins()
        .select(need_nan, dd_nan.hi, normal_result.hi);
    let final_hi = self.builder.ins().select(is_inf, inf_f64, hi_or_nan);
    let lo_or_nan = self
        .builder
        .ins()
        .select(need_nan, dd_nan.lo, normal_result.lo);
    let final_lo = self.builder.ins().select(is_inf, zero_f64, lo_or_nan);
    DDValue {
        hi: final_hi,
        lo: final_lo,
    }
}
/// Computes `exp(dd)` for a double-double argument.
///
/// The argument is reduced as `x = k·ln2 + r` with `k = floor(x/ln2 + 1/2)`,
/// so `|r|` is about `ln2/2` at most. `exp(r)` is summed from its Taylor series
/// and the sum is scaled by `2^k`, applied as two powers of two so that results
/// near the edges of the `f64` range (including subnormal results) come from
/// rounding rather than from an invalid exponent field.
///
/// Returns `+inf` when the result certainly overflows (including `x = +inf`),
/// `0` when it certainly underflows (including `x = -inf`), and NaN for NaN.
fn dd_exp(&mut self, dd: DDValue) -> DDValue {
    // With |r| <= ~0.347 the 40th Taylor term is below 1e-66, far under
    // double-double resolution, so further terms cannot change the sum.
    const N_TERMS: u32 = 40;
    // For k < MIN_K the result is below half the smallest subnormal (2^-1075);
    // for k > MAX_K it is above f64::MAX. Both cases are replaced explicitly.
    const MIN_K: i64 = -1100;
    const MAX_K: i64 = 1025;

    let x = self.dd_to_f64(dd.clone());
    let ln2_f64 = self
        .builder
        .ins()
        .f64const(f64::from_bits(0x3fe62e42fefa39ef));
    let div = self.builder.ins().fdiv(x, ln2_f64);
    let half = self.builder.ins().f64const(0.5);
    let div_plus_half = self.builder.ins().fadd(div, half);
    // `floor` rounds negative arguments to nearest as well. The saturating
    // conversion maps NaN to 0 and huge values (including ±inf) to
    // i64::MIN/MAX instead of trapping.
    let k_float = self.builder.ins().floor(div_plus_half);
    let k = self.builder.ins().fcvt_to_sint_sat(types::I64, k_float);
    let is_overflow = self
        .builder
        .ins()
        .icmp_imm(IntCC::SignedGreaterThan, k, MAX_K);
    let is_underflow = self
        .builder
        .ins()
        .icmp_imm(IntCC::SignedLessThan, k, MIN_K);
    let inf = self.builder.ins().f64const(f64::INFINITY);
    let zero = self.builder.ins().f64const(0.0);

    let ln2_dd = self.dd_from_parts(
        f64::from_bits(0x3fe62e42fefa39ef),
        f64::from_bits(0x3c7abc9e3b39803f),
    );
    let k_f64 = self.builder.ins().fcvt_from_sint(types::F64, k);
    let k_ln2 = self.dd_mul_f64(ln2_dd, k_f64);
    let r = self.dd_sub(dd, k_ln2);

    let mut sum = self.dd_from_f64(1.0);
    let mut term = self.dd_from_f64(1.0);
    for i in 1..=N_TERMS {
        term = self.dd_mul(term, r.clone());
        let inv_const = self.builder.ins().f64const(1.0 / f64::from(i));
        term = self.dd_mul_f64(term, inv_const);
        sum = self.dd_add(sum, term.clone());
    }

    // Clamp k to [MIN_K, MAX_K] (out-of-range results are replaced below) and
    // split it into two halves whose biased exponents are always normal.
    let min_k = self.builder.ins().iconst(types::I64, MIN_K);
    let max_k = self.builder.ins().iconst(types::I64, MAX_K);
    let k_low_clamped = self.builder.ins().select(is_underflow, min_k, k);
    let k_clamped = self.builder.ins().select(is_overflow, max_k, k_low_clamped);
    let k_half = self.builder.ins().sshr_imm(k_clamped, 1);
    let k_rest = self.builder.ins().isub(k_clamped, k_half);

    let biased_half = self.builder.ins().iadd_imm(k_half, 1023);
    let bits_half = self.builder.ins().ishl_imm(biased_half, 52);
    let scale_half = self
        .builder
        .ins()
        .bitcast(types::F64, MemFlags::new(), bits_half);
    let biased_rest = self.builder.ins().iadd_imm(k_rest, 1023);
    let bits_rest = self.builder.ins().ishl_imm(biased_rest, 52);
    let scale_rest = self
        .builder
        .ins()
        .bitcast(types::F64, MemFlags::new(), bits_rest);

    let partial = self.dd_scale(sum, scale_half);
    let result = self.dd_scale(partial, scale_rest);

    let hi_or_inf = self.builder.ins().select(is_overflow, inf, result.hi);
    let final_hi = self.builder.ins().select(is_underflow, zero, hi_or_inf);
    let lo_or_zero = self.builder.ins().select(is_overflow, zero, result.lo);
    let final_lo = self.builder.ins().select(is_underflow, zero, lo_or_zero);
    DDValue {
        hi: final_hi,
        lo: final_lo,
    }
}
fn compile_fpow(&mut self, a: Value, b: Value) -> Value {
let f64_ty = types::F64;
let i64_ty = types::I64;
let zero_f = self.builder.ins().f64const(0.0);
let one_f = self.builder.ins().f64const(1.0);
let nan_f = self.builder.ins().f64const(f64::NAN);
let inf_f = self.builder.ins().f64const(f64::INFINITY);
let neg_inf_f = self.builder.ins().f64const(f64::NEG_INFINITY);
let merge_block = self.builder.create_block();
self.builder.append_block_param(merge_block, f64_ty);
let cmp_b_zero = self.builder.ins().fcmp(FloatCC::Equal, b, zero_f);
let b_zero_block = self.builder.create_block();
let continue_block = self.builder.create_block();
self.builder
.ins()
.brif(cmp_b_zero, b_zero_block, &[], continue_block, &[]);
self.builder.switch_to_block(b_zero_block);
self.builder.ins().jump(merge_block, &[one_f.into()]);
self.builder.switch_to_block(continue_block);
let cmp_b_nan = self.builder.ins().fcmp(FloatCC::Unordered, b, b);
let b_nan_block = self.builder.create_block();
let continue_block2 = self.builder.create_block();
self.builder
.ins()
.brif(cmp_b_nan, b_nan_block, &[], continue_block2, &[]);
self.builder.switch_to_block(b_nan_block);
self.builder.ins().jump(merge_block, &[nan_f.into()]);
self.builder.switch_to_block(continue_block2);
let cmp_a_zero = self.builder.ins().fcmp(FloatCC::Equal, a, zero_f);
let a_zero_block = self.builder.create_block();
let continue_block3 = self.builder.create_block();
self.builder
.ins()
.brif(cmp_a_zero, a_zero_block, &[], continue_block3, &[]);
self.builder.switch_to_block(a_zero_block);
self.builder.ins().jump(merge_block, &[zero_f.into()]);
self.builder.switch_to_block(continue_block3);
let cmp_a_nan = self.builder.ins().fcmp(FloatCC::Unordered, a, a);
let a_nan_block = self.builder.create_block();
let continue_block4 = self.builder.create_block();
self.builder
.ins()
.brif(cmp_a_nan, a_nan_block, &[], continue_block4, &[]);
self.builder.switch_to_block(a_nan_block);
self.builder.ins().jump(merge_block, &[nan_f.into()]);
self.builder.switch_to_block(continue_block4);
let cmp_b_inf = self.builder.ins().fcmp(FloatCC::Equal, b, inf_f);
let b_inf_block = self.builder.create_block();
let continue_block5 = self.builder.create_block();
self.builder
.ins()
.brif(cmp_b_inf, b_inf_block, &[], continue_block5, &[]);
self.builder.switch_to_block(b_inf_block);
self.builder.ins().jump(merge_block, &[inf_f.into()]);
self.builder.switch_to_block(continue_block5);
let cmp_b_neg_inf = self.builder.ins().fcmp(FloatCC::Equal, b, neg_inf_f);
let b_neg_inf_block = self.builder.create_block();
let continue_block6 = self.builder.create_block();
self.builder
.ins()
.brif(cmp_b_neg_inf, b_neg_inf_block, &[], continue_block6, &[]);
self.builder.switch_to_block(b_neg_inf_block);
self.builder.ins().jump(merge_block, &[zero_f.into()]);
self.builder.switch_to_block(continue_block6);
let cmp_a_inf = self.builder.ins().fcmp(FloatCC::Equal, a, inf_f);
let a_inf_block = self.builder.create_block();
let continue_block7 = self.builder.create_block();
self.builder
.ins()
.brif(cmp_a_inf, a_inf_block, &[], continue_block7, &[]);
self.builder.switch_to_block(a_inf_block);
self.builder.ins().jump(merge_block, &[inf_f.into()]);
self.builder.switch_to_block(continue_block7);
let cmp_a_neg_inf = self.builder.ins().fcmp(FloatCC::Equal, a, neg_inf_f);
let a_neg_inf_block = self.builder.create_block();
let continue_block8 = self.builder.create_block();
self.builder
.ins()
.brif(cmp_a_neg_inf, a_neg_inf_block, &[], continue_block8, &[]);
self.builder.switch_to_block(a_neg_inf_block);
let b_floor = self.builder.ins().floor(b);
let cmp_int = self.builder.ins().fcmp(FloatCC::Equal, b_floor, b);
let domain_error_blk = self.builder.create_block();
let continue_neg_inf = self.builder.create_block();
self.builder
.ins()
.brif(cmp_int, continue_neg_inf, &[], domain_error_blk, &[]);
self.builder.switch_to_block(domain_error_blk);
self.builder.ins().jump(merge_block, &[nan_f.into()]);
self.builder.switch_to_block(continue_neg_inf);
let b_i64 = self.builder.ins().fcvt_to_sint(i64_ty, b_floor);
let one_i = self.builder.ins().iconst(i64_ty, 1);
let remainder = self.builder.ins().band(b_i64, one_i);
let zero_i = self.builder.ins().iconst(i64_ty, 0);
let is_odd = self.builder.ins().icmp(IntCC::NotEqual, remainder, zero_i);
let odd_block = self.builder.create_block();
let even_block = self.builder.create_block();
self.builder.append_block_param(odd_block, f64_ty);
self.builder.append_block_param(even_block, f64_ty);
self.builder.ins().brif(
is_odd,
odd_block,
&[neg_inf_f.into()],
even_block,
&[inf_f.into()],
);
self.builder.switch_to_block(odd_block);
let phi_neg_inf = self.builder.block_params(odd_block)[0];
self.builder.ins().jump(merge_block, &[phi_neg_inf.into()]);
self.builder.switch_to_block(even_block);
let phi_inf = self.builder.block_params(even_block)[0];
self.builder.ins().jump(merge_block, &[phi_inf.into()]);
self.builder.switch_to_block(continue_block8);
let cmp_lt = self.builder.ins().fcmp(FloatCC::LessThan, a, zero_f);
let a_neg_block = self.builder.create_block();
let a_pos_block = self.builder.create_block();
self.builder
.ins()
.brif(cmp_lt, a_neg_block, &[], a_pos_block, &[]);
self.builder.switch_to_block(a_pos_block);
let ln_a_dd = self.dd_ln(a);
let b_dd = self.dd_from_value(b);
let product_dd = self.dd_mul(ln_a_dd, b_dd);
let exp_dd = self.dd_exp(product_dd);
let pos_res = self.dd_to_f64(exp_dd);
self.builder.ins().jump(merge_block, &[pos_res.into()]);
self.builder.switch_to_block(a_neg_block);
let b_floor = self.builder.ins().floor(b);
let cmp_int = self.builder.ins().fcmp(FloatCC::Equal, b_floor, b);
let neg_int_block = self.builder.create_block();
let domain_error_blk = self.builder.create_block();
self.builder
.ins()
.brif(cmp_int, neg_int_block, &[], domain_error_blk, &[]);
self.builder.switch_to_block(domain_error_blk);
self.builder.ins().jump(merge_block, &[nan_f.into()]);
self.builder.switch_to_block(neg_int_block);
let abs_a = self.builder.ins().fabs(a);
let ln_abs_dd = self.dd_ln(abs_a);
let b_dd = self.dd_from_value(b);
let product_dd = self.dd_mul(ln_abs_dd, b_dd);
let exp_dd = self.dd_exp(product_dd);
let mag_val = self.dd_to_f64(exp_dd);
let b_i64 = self.builder.ins().fcvt_to_sint(i64_ty, b_floor);
let one_i = self.builder.ins().iconst(i64_ty, 1);
let remainder = self.builder.ins().band(b_i64, one_i);
let zero_i = self.builder.ins().iconst(i64_ty, 0);
let is_odd = self.builder.ins().icmp(IntCC::NotEqual, remainder, zero_i);
let odd_block = self.builder.create_block();
let even_block = self.builder.create_block();
self.builder.append_block_param(odd_block, f64_ty);
self.builder.append_block_param(even_block, f64_ty);
self.builder.ins().brif(
is_odd,
odd_block,
&[mag_val.into()],
even_block,
&[mag_val.into()],
);
self.builder.switch_to_block(odd_block);
let phi_mag_val = self.builder.block_params(odd_block)[0];
let neg_val = self.builder.ins().fneg(phi_mag_val);
self.builder.ins().jump(merge_block, &[neg_val.into()]);
self.builder.switch_to_block(even_block);
let phi_mag_val_even = self.builder.block_params(even_block)[0];
self.builder
.ins()
.jump(merge_block, &[phi_mag_val_even.into()]);
self.builder.switch_to_block(merge_block);
self.builder.block_params(merge_block)[0]
}
/// Emits IR computing `a ** b` for two `i64` operands and returns the resulting value.
///
/// Non-negative exponents use binary exponentiation (square-and-multiply). All
/// multiplications are Cranelift `imul`, which wraps on overflow. Results that do not
/// fit in an `i64` therefore wrap instead of growing into a big integer as in Python.
///
/// A negative exponent cannot produce Python's float result through an `i64` value.
/// In that case the exact result is truncated toward zero:
/// - `1 ** -n` is `1`;
/// - `(-1) ** -n` is `-1` for odd `n` and `1` for even `n`;
/// - every other base yields `0`.
///
/// `0 ** -n` also yields `0`. Python raises `ZeroDivisionError` here, but that is
/// not signalled by this IR.
fn compile_ipow(&mut self, a: Value, b: Value) -> Value {
    let zero = self.builder.ins().iconst(types::I64, 0);
    let one = self.builder.ins().iconst(types::I64, 1);

    let negative_block = self.builder.create_block();
    let loop_block = self.builder.create_block();
    let body_block = self.builder.create_block();
    let exit_block = self.builder.create_block();

    // Loop state carried through block params: (exponent, accumulated result, current base).
    for block in [loop_block, body_block] {
        for _ in 0..3 {
            self.builder.append_block_param(block, types::I64);
        }
    }
    // The single exit param is the final result.
    self.builder.append_block_param(exit_block, types::I64);

    // Dispatch on the sign of the exponent straight from the current block.
    let exp_is_negative = self
        .builder
        .ins()
        .icmp_imm(IntCC::SignedLessThan, b, 0);
    self.builder.ins().brif(
        exp_is_negative,
        negative_block,
        &[],
        loop_block,
        &[b.into(), one.into(), a.into()],
    );

    // Negative exponent: only bases 1 and -1 have a non-zero truncated result.
    self.builder.switch_to_block(negative_block);
    let minus_one = self.builder.ins().iconst(types::I64, -1);
    // In two's complement the low bit encodes parity for negative values too.
    let neg_exp_low_bit = self.builder.ins().band_imm(b, 1);
    let neg_exp_is_odd = self
        .builder
        .ins()
        .icmp_imm(IntCC::NotEqual, neg_exp_low_bit, 0);
    let minus_one_result = self.builder.ins().select(neg_exp_is_odd, minus_one, one);
    let base_is_minus_one = self.builder.ins().icmp_imm(IntCC::Equal, a, -1);
    let base_is_one = self.builder.ins().icmp_imm(IntCC::Equal, a, 1);
    let unit_result = self
        .builder
        .ins()
        .select(base_is_minus_one, minus_one_result, zero);
    let negative_result = self.builder.ins().select(base_is_one, one, unit_result);
    self.builder.ins().jump(exit_block, &[negative_result.into()]);

    // Loop header: finish once the exponent has been shifted down to zero.
    self.builder.switch_to_block(loop_block);
    let params = self.builder.block_params(loop_block);
    let (exp, acc, base) = (params[0], params[1], params[2]);
    let exp_is_zero = self.builder.ins().icmp_imm(IntCC::Equal, exp, 0);
    self.builder.ins().brif(
        exp_is_zero,
        exit_block,
        &[acc.into()],
        body_block,
        &[exp.into(), acc.into(), base.into()],
    );

    // Loop body: fold in the base when the low exponent bit is set, then square the base.
    self.builder.switch_to_block(body_block);
    let params = self.builder.block_params(body_block);
    let (exp, acc, base) = (params[0], params[1], params[2]);
    let low_bit = self.builder.ins().band_imm(exp, 1);
    let exp_is_odd = self.builder.ins().icmp_imm(IntCC::NotEqual, low_bit, 0);
    let multiplied = self.builder.ins().imul(acc, base);
    let next_acc = self.builder.ins().select(exp_is_odd, multiplied, acc);
    // On the last iteration this square is unused; `imul` wraps, so it cannot trap.
    let next_base = self.builder.ins().imul(base, base);
    // The exponent is non-negative here, so an arithmetic shift is a plain halving.
    let next_exp = self.builder.ins().sshr_imm(exp, 1);
    self.builder.ins().jump(
        loop_block,
        &[next_exp.into(), next_acc.into(), next_base.into()],
    );

    self.builder.switch_to_block(exit_block);
    // All predecessors of these blocks have now been emitted.
    for block in [negative_block, loop_block, body_block, exit_block] {
        self.builder.seal_block(block);
    }
    self.builder.block_params(exit_block)[0]
}
}