mod instructions;
extern crate alloc;
use alloc::fmt;
use core::mem::ManuallyDrop;
use cranelift::prelude::*;
use cranelift_jit::{JITBuilder, JITModule};
use cranelift_module::{FuncId, Linkage, Module, ModuleError};
use instructions::FunctionCompiler;
use rustpython_compiler_core::bytecode;
/// Reasons a code object could not be turned into native code.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum JitCompileError {
    /// The function uses something the JIT does not handle.
    #[error("function can't be jitted")]
    NotSupported,
    /// The bytecode handed to the JIT was not in a shape it could translate.
    #[error("bad bytecode")]
    BadBytecode,
    /// A Cranelift module operation (declare / define / finalize) failed.
    #[error("error while compiling to machine code: {0}")]
    CraneliftError(Box<ModuleError>),
}

impl From<ModuleError> for JitCompileError {
    /// Wraps a Cranelift module error; lets `?` be used directly on module calls.
    fn from(err: ModuleError) -> Self {
        // Boxing keeps this enum (and every `Result` carrying it) pointer-sized
        // for this variant — presumably the reason the payload is boxed.
        Self::CraneliftError(err.into())
    }
}
#[derive(Debug, thiserror::Error, Eq, PartialEq)]
#[non_exhaustive]
pub enum JitArgumentError {
#[error("argument is of wrong type")]
ArgumentTypeMismatch,
#[error("wrong number of arguments")]
WrongNumberOfArguments,
}
/// Scratch state for compiling one function: Cranelift's function-builder and
/// codegen contexts, plus the JIT module that receives the machine code.
struct Jit {
    builder_context: FunctionBuilderContext,
    ctx: codegen::Context,
    module: JITModule,
}

impl Jit {
    /// Creates a fresh JIT module configured with Cranelift's default libcall names.
    ///
    /// # Panics
    /// Panics if the `JITBuilder` cannot be constructed.
    fn new() -> Self {
        let builder = JITBuilder::new(cranelift_module::default_libcall_names())
            .expect("Failed to build JITBuilder");
        let module = JITModule::new(builder);
        Self {
            builder_context: FunctionBuilderContext::new(),
            ctx: module.make_context(),
            module,
        }
    }

    /// Translates `bytecode` into a Cranelift function named `jit_<obj_name>` and
    /// defines it in `self.module`.
    ///
    /// `args` become the function's parameters, and `ret`, if present, becomes its
    /// single return value. On success, this returns the declared function id together
    /// with the signature recorded by the bytecode compiler. The module must still be
    /// finalized before the code can be fetched.
    ///
    /// # Errors
    /// Propagates any error from `FunctionCompiler::compile` and any `ModuleError`
    /// from declaring or defining the function.
    fn build_function<C: bytecode::Constant>(
        &mut self,
        bytecode: &bytecode::CodeObject<C>,
        args: &[JitType],
        ret: Option<JitType>,
    ) -> Result<(FuncId, JitSig), JitCompileError> {
        let signature = &mut self.ctx.func.signature;
        signature
            .params
            .extend(args.iter().map(|arg| AbiParam::new(arg.to_cranelift())));
        // Borrow the return type instead of cloning it: `ret` is still needed below.
        if let Some(ty) = &ret {
            signature.returns.push(AbiParam::new(ty.to_cranelift()));
        }

        let id = self.module.declare_function(
            &format!("jit_{}", bytecode.obj_name.as_ref()),
            Linkage::Export,
            &self.ctx.func.signature,
        )?;
        // A reference to the function inside its own body; presumably used by the
        // compiler to emit self-recursive calls.
        let func_ref = self.module.declare_func_in_func(id, &mut self.ctx.func);

        let mut builder = FunctionBuilder::new(&mut self.ctx.func, &mut self.builder_context);
        let entry_block = builder.create_block();
        builder.append_block_params_for_function_params(entry_block);
        builder.switch_to_block(entry_block);

        let sig = {
            let mut compiler = FunctionCompiler::new(
                &mut builder,
                bytecode.varnames.len(),
                args,
                ret,
                entry_block,
            );
            compiler.compile(func_ref, bytecode)?;
            compiler.sig
        };

        builder.seal_all_blocks();
        builder.finalize();

        self.module.define_function(id, &mut self.ctx)?;
        // Reset the codegen context so it no longer holds this function's IR.
        self.module.clear_context(&mut self.ctx);
        Ok((id, sig))
    }
}
/// Compiles `bytecode` to native code taking `args` and returning `ret` (or nothing).
///
/// The returned [`CompiledCode`] owns the generated machine code and releases it
/// when dropped.
///
/// # Errors
/// Returns any [`JitCompileError`] raised while building, defining, or finalizing
/// the function. On failure, all memory the JIT module allocated is freed before
/// returning, instead of being leaked along with the module.
pub fn compile<C: bytecode::Constant>(
    bytecode: &bytecode::CodeObject<C>,
    args: &[JitType],
    ret: Option<JitType>,
) -> Result<CompiledCode, JitCompileError> {
    let mut jit = Jit::new();
    let outcome = jit.build_function(bytecode, args, ret).and_then(|built| {
        jit.module.finalize_definitions()?;
        Ok(built)
    });
    let (id, sig) = match outcome {
        Ok(built) => built,
        Err(err) => {
            // Dropping a `JITModule` deliberately leaks its code memory, so free it
            // explicitly on the error path.
            // SAFETY: `get_finalized_function` was never called, so no pointer into
            // this module's memory exists and nothing can run the freed code.
            unsafe { jit.module.free_memory() };
            return Err(err);
        }
    };
    let code = jit.module.get_finalized_function(id);
    Ok(CompiledCode {
        sig,
        code,
        module: ManuallyDrop::new(jit.module),
    })
}
/// Native code produced by [`compile`], plus the signature it was built with and
/// the JIT module that owns the code memory.
///
/// `module` is wrapped in `ManuallyDrop` so that its memory can be released
/// explicitly with `free_memory` in this type's `Drop` impl.
pub struct CompiledCode {
    sig: JitSig,
    code: *const u8,
    module: ManuallyDrop<JITModule>,
}

impl CompiledCode {
    /// Starts building an argument list one positional slot at a time.
    pub fn args_builder(&self) -> ArgsBuilder<'_> {
        ArgsBuilder::new(self)
    }

    /// Calls the compiled function with `args`.
    ///
    /// Returns `Ok(None)` when the function was compiled without a return type.
    ///
    /// # Errors
    /// * `WrongNumberOfArguments` if `args.len()` differs from the signature's arity.
    /// * `ArgumentTypeMismatch` for the first argument whose variant does not match
    ///   its declared type.
    pub fn invoke(&self, args: &[AbiValue]) -> Result<Option<AbiValue>, JitArgumentError> {
        let declared = &self.sig.args;
        if args.len() != declared.len() {
            return Err(JitArgumentError::WrongNumberOfArguments);
        }

        let mut ffi_args = Vec::with_capacity(args.len());
        for (expected, actual) in declared.iter().zip(args) {
            type_check(expected, actual)?;
            ffi_args.push(actual.to_libffi_arg());
        }

        // SAFETY: the arity and every argument's type were checked against
        // `self.sig` above.
        let returned = unsafe { self.invoke_raw(&ffi_args) };
        Ok(returned)
    }

    /// Calls the native entry point through libffi and decodes the return value.
    ///
    /// # Safety
    /// `cif_args` must hold exactly one argument per entry of `self.sig.args`, and
    /// each argument must be a value of the corresponding type.
    unsafe fn invoke_raw(&self, cif_args: &[libffi::middle::Arg<'_>]) -> Option<AbiValue> {
        let cif = self.sig.to_cif();
        unsafe {
            let entry = libffi::middle::CodePtr::from_ptr(self.code as *const _);
            let raw = cif.call::<UnTypedAbiValue>(entry, cif_args);
            self.sig.ret.as_ref().map(|ty| raw.to_typed(ty))
        }
    }
}
/// Argument and (optional) return types of a compiled function.
struct JitSig {
    args: Vec<JitType>,
    ret: Option<JitType>,
}

impl JitSig {
    /// Builds the libffi call interface for this signature; a missing return type
    /// is described to libffi as `void`.
    fn to_cif(&self) -> libffi::middle::Cif {
        let return_type = self
            .ret
            .as_ref()
            .map_or_else(libffi::middle::Type::void, JitType::to_libffi);
        let param_types = self.args.iter().map(JitType::to_libffi);
        libffi::middle::Cif::new(param_types, return_type)
    }
}
/// A value type that may appear as an argument to, or the result of, JIT code.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum JitType {
    /// Lowered to a 64-bit integer.
    Int,
    /// Lowered to a 64-bit float.
    Float,
    /// Lowered to a single byte.
    Bool,
}

impl JitType {
    /// The Cranelift IR type used to represent this kind of value.
    fn to_cranelift(&self) -> types::Type {
        match *self {
            JitType::Int => types::I64,
            JitType::Float => types::F64,
            JitType::Bool => types::I8,
        }
    }

    /// The libffi type used when passing this kind of value across the call boundary.
    fn to_libffi(&self) -> libffi::middle::Type {
        use libffi::middle::Type as FfiType;
        match *self {
            JitType::Int => FfiType::i64(),
            JitType::Float => FfiType::f64(),
            JitType::Bool => FfiType::u8(),
        }
    }
}
/// A concrete value passed into, or returned from, JIT-compiled code.
///
/// Each variant mirrors one `JitType`: `Float` ↔ `Float`, `Int` ↔ `Int`,
/// `Bool` ↔ `Bool`.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub enum AbiValue {
    /// A 64-bit float.
    Float(f64),
    /// A 64-bit signed integer.
    Int(i64),
    /// A boolean.
    Bool(bool),
}
impl AbiValue {
    /// Wraps a reference to the payload as a libffi argument; the returned `Arg`
    /// borrows from `self`.
    fn to_libffi_arg(&self) -> libffi::middle::Arg<'_> {
        use libffi::middle::Arg;
        match self {
            Self::Float(float) => Arg::new(float),
            Self::Int(int) => Arg::new(int),
            Self::Bool(flag) => Arg::new(flag),
        }
    }
}
impl From<i64> for AbiValue {
fn from(i: i64) -> Self {
AbiValue::Int(i)
}
}
impl From<f64> for AbiValue {
fn from(f: f64) -> Self {
AbiValue::Float(f)
}
}
impl From<bool> for AbiValue {
fn from(b: bool) -> Self {
AbiValue::Bool(b)
}
}
impl TryFrom<AbiValue> for i64 {
type Error = ();
fn try_from(value: AbiValue) -> Result<Self, Self::Error> {
match value {
AbiValue::Int(i) => Ok(i),
_ => Err(()),
}
}
}
impl TryFrom<AbiValue> for f64 {
type Error = ();
fn try_from(value: AbiValue) -> Result<Self, Self::Error> {
match value {
AbiValue::Float(f) => Ok(f),
_ => Err(()),
}
}
}
impl TryFrom<AbiValue> for bool {
type Error = ();
fn try_from(value: AbiValue) -> Result<Self, Self::Error> {
match value {
AbiValue::Bool(b) => Ok(b),
_ => Err(()),
}
}
}
/// Verifies that `val` is the `AbiValue` variant matching the declared type `ty`.
///
/// # Errors
/// Returns `ArgumentTypeMismatch` when the variant does not match.
fn type_check(ty: &JitType, val: &AbiValue) -> Result<(), JitArgumentError> {
    let compatible = matches!(
        (ty, val),
        (JitType::Int, AbiValue::Int(_))
            | (JitType::Float, AbiValue::Float(_))
            | (JitType::Bool, AbiValue::Bool(_))
    );
    if compatible {
        Ok(())
    } else {
        Err(JitArgumentError::ArgumentTypeMismatch)
    }
}
/// Raw return slot that libffi writes the callee's result into.
///
/// Which field holds meaningful data depends on the signature's return type, so
/// reading it is `unsafe`. The `()` field presumably covers calls with a `void`
/// return, where nothing is written.
#[derive(Copy, Clone)]
union UnTypedAbiValue {
    float: f64,
    int: i64,
    boolean: u8,
    _void: (),
}

impl UnTypedAbiValue {
    /// Reinterprets the raw slot as an `AbiValue` of type `ty`; any nonzero byte
    /// decodes as `true` for `Bool`.
    ///
    /// # Safety
    /// The field that corresponds to `ty` must have been initialized, e.g. by a
    /// libffi call whose declared return type was `ty.to_libffi()`.
    unsafe fn to_typed(self, ty: &JitType) -> AbiValue {
        // SAFETY (each arm): the caller guarantees the field matching `ty` is initialized.
        match ty {
            JitType::Int => AbiValue::Int(unsafe { self.int }),
            JitType::Float => AbiValue::Float(unsafe { self.float }),
            JitType::Bool => AbiValue::Bool(unsafe { self.boolean } != 0),
        }
    }
}
// SAFETY: NOTE(review): `CompiledCode` is `!Send`/`!Sync` only because of the raw
// `code` pointer and the `JITModule`. Through `&self`, the pointer is only read
// (to call it) and the module is never touched; the module is accessed solely in
// `Drop`, which has exclusive access. This also assumes the generated code is safe
// to execute from several threads at once. Confirm that against the compiler in
// `instructions`.
unsafe impl Send for CompiledCode {}
unsafe impl Sync for CompiledCode {}

impl Drop for CompiledCode {
    /// Releases the JIT module's memory, which also invalidates `self.code`.
    fn drop(&mut self) {
        // SAFETY: `self.module` is taken exactly once, here, and never used again.
        let module = unsafe { ManuallyDrop::take(&mut self.module) };
        // SAFETY: `self` is being destroyed, and `Args`/`ArgsBuilder` borrow it, so no
        // call through `self.code` can happen after this point.
        unsafe { module.free_memory() }
    }
}

impl fmt::Debug for CompiledCode {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // The machine code is opaque, so print a fixed placeholder.
        f.write_str("[compiled code]")
    }
}
/// Incrementally assembles the arguments for a call to a [`CompiledCode`],
/// type-checking each argument as it is stored.
pub struct ArgsBuilder<'a> {
    values: Vec<Option<AbiValue>>,
    code: &'a CompiledCode,
}

impl<'a> ArgsBuilder<'a> {
    /// Creates a builder with one empty slot per declared argument.
    fn new(code: &'a CompiledCode) -> ArgsBuilder<'a> {
        ArgsBuilder {
            values: vec![None; code.sig.args.len()],
            code,
        }
    }

    /// Stores `value` in slot `idx` after checking it against that slot's declared type.
    ///
    /// # Errors
    /// Returns `ArgumentTypeMismatch` if `value` does not match the slot's type. In
    /// that case the slot is left unchanged.
    ///
    /// # Panics
    /// Panics if `idx` is not less than the function's argument count.
    pub fn set(&mut self, idx: usize, value: AbiValue) -> Result<(), JitArgumentError> {
        type_check(&self.code.sig.args[idx], &value)?;
        self.values[idx] = Some(value);
        Ok(())
    }

    /// Reports whether slot `idx` has been filled.
    ///
    /// # Panics
    /// Panics if `idx` is not less than the function's argument count.
    pub fn is_set(&self, idx: usize) -> bool {
        self.values[idx].is_some()
    }

    /// Finishes the builder. Returns `None` if any slot is still empty.
    pub fn into_args(self) -> Option<Args<'a>> {
        let Self { values, code } = self;
        // Collecting into `Option<Vec<_>>` yields `None` if any slot is `None`.
        let values = values.into_iter().collect::<Option<Vec<_>>>()?;
        Some(Args { values, code })
    }
}
pub struct Args<'a> {
values: Vec<AbiValue>,
code: &'a CompiledCode,
}
impl Args<'_> {
pub fn invoke(&self) -> Option<AbiValue> {
let cif_args: Vec<_> = self.values.iter().map(AbiValue::to_libffi_arg).collect();
unsafe { self.code.invoke_raw(&cif_args) }
}
}