use crate::ast::{ASTBinOp, ASTFun, ASTNode};
use crate::context::{DSPFunction, DSPNodeContext, DSPNodeSigBit, DSPNodeType, DSPNodeTypeLibrary};
use cranelift::prelude::types::{F64, I32};
use cranelift::prelude::InstBuilder;
use cranelift::prelude::*;
use cranelift_codegen::ir::immediates::Offset32;
use cranelift_codegen::settings::{self, Configurable};
use cranelift_jit::{JITBuilder, JITModule};
use cranelift_module::default_libcall_names;
use cranelift_module::{FuncId, Linkage, Module};
use std::cell::RefCell;
use std::collections::HashMap;
use std::rc::Rc;
/// JIT compiler that lowers an [`ASTFun`] to a native DSP function with Cranelift.
///
/// A `JIT` is single-use: [`JIT::compile`] takes `self` by value.
pub struct JIT {
    /// Scratch state reused by Cranelift's `FunctionBuilder`.
    builder_context: FunctionBuilderContext,
    /// Codegen context holding the signature and IR of the function being built.
    ctx: codegen::Context,
    /// The JIT module; taken (left as `None`) by `compile` when finalizing the code.
    module: Option<JITModule>,
    /// Node types whose native functions are registered as JIT symbols in `new`.
    dsp_lib: Rc<RefCell<DSPNodeTypeLibrary>>,
    /// Node context that receives node instances and the finalized function.
    dsp_ctx: Rc<RefCell<DSPNodeContext>>,
}
impl JIT {
pub fn new(
dsp_lib: Rc<RefCell<DSPNodeTypeLibrary>>,
dsp_ctx: Rc<RefCell<DSPNodeContext>>,
) -> Self {
let mut flag_builder = settings::builder();
flag_builder
.set("use_colocated_libcalls", "false")
.expect("Setting 'use_colocated_libcalls' works");
flag_builder.set("is_pic", "false").expect("Setting 'is_pic' works");
let isa_builder = cranelift_native::builder().unwrap_or_else(|msg| {
panic!("host machine is not supported: {}", msg);
});
let isa = isa_builder
.finish(settings::Flags::new(flag_builder))
.expect("ISA Builder finish works");
let mut builder = JITBuilder::with_isa(isa, default_libcall_names());
dsp_lib
.borrow()
.for_each(|typ| -> Result<(), JITCompileError> {
builder.symbol(typ.name(), typ.function_ptr());
Ok(())
})
.expect("symbol adding works");
let module = JITModule::new(builder);
Self {
builder_context: FunctionBuilderContext::new(),
ctx: module.make_context(),
module: Some(module),
dsp_lib,
dsp_ctx,
}
}
pub fn compile(mut self, prog: ASTFun) -> Result<Box<DSPFunction>, JITCompileError> {
let module = self.module.as_mut().expect("Module still loaded");
let ptr_type = module.target_config().pointer_type();
for param_idx in 0..prog.param_count() {
if prog.param_is_ref(param_idx) {
self.ctx.func.signature.params.push(AbiParam::new(ptr_type));
} else {
self.ctx.func.signature.params.push(AbiParam::new(F64));
};
}
self.ctx.func.signature.returns.push(AbiParam::new(F64));
let id = module
.declare_function("dsp", Linkage::Export, &self.ctx.func.signature)
.map_err(|e| JITCompileError::DeclareTopFunError(e.to_string()))?;
self.ctx.func.name = ExternalName::user(0, id.as_u32());
self.translate(prog)?;
let mut module = self.module.take().expect("Module still loaded");
module
.define_function(id, &mut self.ctx)
.map_err(|e| JITCompileError::DefineTopFunError(e.to_string()))?;
module.clear_context(&mut self.ctx);
module.finalize_definitions();
let code = module.get_finalized_function(id);
let dsp_fun = self
.dsp_ctx
.borrow_mut()
.finalize_dsp_function(code, module)
.expect("DSPFunction present in DSPNodeContext.");
Ok(dsp_fun)
}
fn translate(&mut self, fun: ASTFun) -> Result<(), JITCompileError> {
let builder = FunctionBuilder::new(&mut self.ctx.func, &mut self.builder_context);
let module = self.module.as_mut().expect("Module still loaded");
let dsp_lib = self.dsp_lib.clone();
let dsp_lib = dsp_lib.borrow();
let dsp_ctx = self.dsp_ctx.clone();
let mut dsp_ctx = dsp_ctx.borrow_mut();
let mut trans = DSPFunctionTranslator::new(&mut *dsp_ctx, &*dsp_lib, builder, module);
trans.register_functions()?;
let ret = trans.translate(fun)?;
Ok(ret)
}
}
/// Translates one [`ASTFun`] into Cranelift IR via a `FunctionBuilder`.
pub(crate) struct DSPFunctionTranslator<'a, 'b, 'c> {
    /// Node context; `Call` nodes register their node instances here.
    dsp_ctx: &'c mut DSPNodeContext,
    /// Node types that `register_functions` declares as imported functions.
    dsp_lib: &'b DSPNodeTypeLibrary,
    /// Builder emitting IR into the function being compiled.
    builder: FunctionBuilder<'a>,
    /// Source-level variable name (parameters and locals) -> Cranelift variable.
    variables: HashMap<String, Variable>,
    /// Next index handed to `Variable::new`.
    var_index: usize,
    /// Module in which the node functions are declared.
    module: &'a mut JITModule,
    /// Node type name -> (node type, id of its imported function).
    dsp_node_functions: HashMap<String, (Rc<dyn DSPNodeType>, FuncId)>,
    /// Target pointer width in bytes; the stride for loading node state pointers from `&fstate`.
    ptr_w: u32,
}
/// Errors produced while compiling an AST function with the JIT.
#[derive(Debug, Clone)]
pub enum JITCompileError {
    /// A parameter index has no name (`ASTFun::param_name` returned `None`).
    BadDefinedParams,
    /// NOTE(review): not produced in this file; presumably an unknown function name.
    UnknownFunction(String),
    /// A referenced variable (or persistent variable) is not defined; payload is its name.
    UndefinedVariable(String),
    /// A `%` return value access other than `%1`..`%5`; payload is the accessed name.
    InvalidReturnValueAccess(String),
    /// Declaring a function in the module failed (the top-level `dsp` function or an
    /// imported node function); payload is the module's error text.
    DeclareTopFunError(String),
    /// Defining the top-level function failed; payload is the module's error text.
    DefineTopFunError(String),
    /// A call names a node type that was not registered; payload is the name.
    UndefinedDSPNode(String),
    /// A call supplies fewer value arguments than the node signature needs: (name, node uid).
    NotEnoughArgsInCall(String, u64),
    /// The node context refused to add a node instance: (its error message, node uid).
    NodeStateError(String, u64),
}
impl<'a, 'b, 'c> DSPFunctionTranslator<'a, 'b, 'c> {
pub fn new(
dsp_ctx: &'c mut DSPNodeContext,
dsp_lib: &'b DSPNodeTypeLibrary,
builder: FunctionBuilder<'a>,
module: &'a mut JITModule,
) -> Self {
dsp_ctx.init_dsp_function();
Self {
dsp_ctx,
dsp_lib,
var_index: 0,
variables: HashMap::new(),
builder,
module,
dsp_node_functions: HashMap::new(),
ptr_w: 8,
}
}
pub fn register_functions(&mut self) -> Result<(), JITCompileError> {
let ptr_type = self.module.target_config().pointer_type();
let mut dsp_node_functions = HashMap::new();
self.dsp_lib.for_each(|typ| {
let mut sig = self.module.make_signature();
let mut i = 0;
while let Some(bit) = typ.signature(i) {
match bit {
DSPNodeSigBit::Value => {
sig.params.push(AbiParam::new(F64));
}
DSPNodeSigBit::DSPStatePtr
| DSPNodeSigBit::NodeStatePtr
| DSPNodeSigBit::MultReturnPtr => {
sig.params.push(AbiParam::new(ptr_type));
}
}
i += 1;
}
if typ.has_return_value() {
sig.returns.push(AbiParam::new(F64));
}
let func_id = self
.module
.declare_function(typ.name(), cranelift_module::Linkage::Import, &sig)
.map_err(|e| JITCompileError::DeclareTopFunError(e.to_string()))?;
dsp_node_functions.insert(typ.name().to_string(), (typ.clone(), func_id));
Ok(())
})?;
self.dsp_node_functions = dsp_node_functions;
Ok(())
}
fn declare_variable(&mut self, typ: types::Type, name: &str) -> Variable {
let var = Variable::new(self.var_index);
if !self.variables.contains_key(name) {
self.variables.insert(name.into(), var);
self.builder.declare_var(var, typ);
self.var_index += 1;
}
var
}
fn translate(&mut self, fun: ASTFun) -> Result<(), JITCompileError> {
let ptr_type = self.module.target_config().pointer_type();
self.ptr_w = ptr_type.bytes();
let entry_block = self.builder.create_block();
self.builder.append_block_params_for_function_params(entry_block);
self.builder.switch_to_block(entry_block);
self.builder.seal_block(entry_block);
self.variables.clear();
for param_idx in 0..fun.param_count() {
let val = self.builder.block_params(entry_block)[param_idx];
match fun.param_name(param_idx) {
Some(param_name) => {
let var = if fun.param_is_ref(param_idx) {
self.declare_variable(ptr_type, param_name)
} else {
self.declare_variable(F64, param_name)
};
self.builder.def_var(var, val);
}
None => {
return Err(JITCompileError::BadDefinedParams);
}
}
}
for local_name in fun.local_variables().iter() {
let zero = self.builder.ins().f64const(0.0);
let var = self.declare_variable(F64, local_name);
self.builder.def_var(var, zero);
}
let v = self.compile(fun.ast_ref())?;
self.builder.ins().return_(&[v]);
self.builder.finalize();
Ok(())
}
fn ins_b_to_f64(&mut self, v: Value) -> Value {
let bint = self.builder.ins().bint(I32, v);
self.builder.ins().fcvt_from_uint(F64, bint)
}
fn compile(&mut self, ast: &Box<ASTNode>) -> Result<Value, JITCompileError> {
match ast.as_ref() {
ASTNode::Lit(v) => Ok(self.builder.ins().f64const(*v)),
ASTNode::Var(name) => {
if name.chars().next() == Some('&') {
let variable = self
.variables
.get(name)
.ok_or_else(|| JITCompileError::UndefinedVariable(name.to_string()))?;
let ptr = self.builder.use_var(*variable);
Ok(self.builder.ins().load(F64, MemFlags::new(), ptr, 0))
} else if name.chars().next() == Some('*') {
let pv_index = self
.dsp_ctx
.get_persistent_variable_index(name)
.or_else(|_| Err(JITCompileError::UndefinedVariable(name.to_string())))?;
let persistent_vars = self
.variables
.get("&pv")
.ok_or_else(|| JITCompileError::UndefinedVariable("&pv".to_string()))?;
let pvs = self.builder.use_var(*persistent_vars);
let pers_value = self.builder.ins().load(
F64,
MemFlags::new(),
pvs,
Offset32::new(pv_index as i32 * F64.bytes() as i32),
);
Ok(pers_value)
} else if name.chars().next() == Some('%') {
if name.len() > 2 {
return Err(JITCompileError::InvalidReturnValueAccess(name.to_string()));
}
let offs: i32 =
match name.chars().nth(1) {
Some('1') => 0,
Some('2') => 1,
Some('3') => 2,
Some('4') => 3,
Some('5') => 4,
_ => {
return Err(JITCompileError::InvalidReturnValueAccess(name.to_string()));
},
};
let return_vals = self
.variables
.get("&rv")
.ok_or_else(|| JITCompileError::UndefinedVariable("&rv".to_string()))?;
let rvs = self.builder.use_var(*return_vals);
let ret_value = self.builder.ins().load(
F64,
MemFlags::new(),
rvs,
Offset32::new(offs * F64.bytes() as i32),
);
Ok(ret_value)
} else {
let variable = self
.variables
.get(name)
.ok_or_else(|| JITCompileError::UndefinedVariable(name.to_string()))?;
Ok(self.builder.use_var(*variable))
}
}
ASTNode::Assign(name, ast) => {
let value = self.compile(ast)?;
if name.chars().next() == Some('&') {
let variable = self
.variables
.get(name)
.ok_or_else(|| JITCompileError::UndefinedVariable(name.to_string()))?;
let ptr = self.builder.use_var(*variable);
self.builder.ins().store(MemFlags::new(), value, ptr, 0);
} else if name.chars().next() == Some('*') {
let pv_index = self
.dsp_ctx
.get_persistent_variable_index(name)
.or_else(|_| Err(JITCompileError::UndefinedVariable(name.to_string())))?;
let persistent_vars = self
.variables
.get("&pv")
.ok_or_else(|| JITCompileError::UndefinedVariable("&pv".to_string()))?;
let pvs = self.builder.use_var(*persistent_vars);
self.builder.ins().store(
MemFlags::new(),
value,
pvs,
Offset32::new(pv_index as i32 * F64.bytes() as i32),
);
} else {
let variable = self
.variables
.get(name)
.ok_or_else(|| JITCompileError::UndefinedVariable(name.to_string()))?;
self.builder.def_var(*variable, value);
}
Ok(value)
}
ASTNode::BinOp(op, a, b) => {
let value_a = self.compile(a)?;
let value_b = self.compile(b)?;
let value = match op {
ASTBinOp::Add => self.builder.ins().fadd(value_a, value_b),
ASTBinOp::Sub => self.builder.ins().fsub(value_a, value_b),
ASTBinOp::Mul => self.builder.ins().fmul(value_a, value_b),
ASTBinOp::Div => self.builder.ins().fdiv(value_a, value_b),
ASTBinOp::Eq => {
let cmp_res = self.builder.ins().fcmp(FloatCC::Equal, value_a, value_b);
self.ins_b_to_f64(cmp_res)
}
ASTBinOp::Ne => {
let cmp_res = self.builder.ins().fcmp(FloatCC::Equal, value_a, value_b);
let bnot = self.builder.ins().bnot(cmp_res);
let bint = self.builder.ins().bint(I32, bnot);
self.builder.ins().fcvt_from_uint(F64, bint)
}
ASTBinOp::Ge => {
let cmp_res =
self.builder.ins().fcmp(FloatCC::GreaterThanOrEqual, value_a, value_b);
self.ins_b_to_f64(cmp_res)
}
ASTBinOp::Le => {
let cmp_res =
self.builder.ins().fcmp(FloatCC::LessThanOrEqual, value_a, value_b);
self.ins_b_to_f64(cmp_res)
}
ASTBinOp::Gt => {
let cmp_res =
self.builder.ins().fcmp(FloatCC::GreaterThan, value_a, value_b);
self.ins_b_to_f64(cmp_res)
}
ASTBinOp::Lt => {
let cmp_res = self.builder.ins().fcmp(FloatCC::LessThan, value_a, value_b);
self.ins_b_to_f64(cmp_res)
}
};
Ok(value)
}
ASTNode::Call(name, dsp_node_uid, args) => {
let func = self
.dsp_node_functions
.get(name)
.ok_or_else(|| JITCompileError::UndefinedDSPNode(name.to_string()))?
.clone();
let node_type = func.0;
let func_id = func.1;
let ptr_type = self.module.target_config().pointer_type();
let mut dsp_node_fun_params = vec![];
let mut i = 0;
let mut arg_idx = 0;
while let Some(bit) = node_type.signature(i) {
match bit {
DSPNodeSigBit::Value => {
if arg_idx >= args.len() {
return Err(JITCompileError::NotEnoughArgsInCall(
name.to_string(),
*dsp_node_uid,
));
}
dsp_node_fun_params.push(self.compile(&args[arg_idx])?);
arg_idx += 1;
}
DSPNodeSigBit::DSPStatePtr => {
let state_var = self.variables.get("&state").ok_or_else(|| {
JITCompileError::UndefinedVariable("&state".to_string())
})?;
dsp_node_fun_params.push(self.builder.use_var(*state_var));
}
DSPNodeSigBit::NodeStatePtr => {
let node_state_index = match self
.dsp_ctx
.add_dsp_node_instance(node_type.clone(), *dsp_node_uid)
{
Err(e) => {
return Err(JITCompileError::NodeStateError(e, *dsp_node_uid));
}
Ok(idx) => idx,
};
let fstate_var = self.variables.get("&fstate").ok_or_else(|| {
JITCompileError::UndefinedVariable("&fstate".to_string())
})?;
let fptr = self.builder.use_var(*fstate_var);
let func_state = self.builder.ins().load(
ptr_type,
MemFlags::new(),
fptr,
Offset32::new(node_state_index as i32 * self.ptr_w as i32),
);
dsp_node_fun_params.push(func_state);
}
DSPNodeSigBit::MultReturnPtr => {
let ret_var = self.variables.get("&rv").ok_or_else(|| {
JITCompileError::UndefinedVariable("&rv".to_string())
})?;
dsp_node_fun_params.push(self.builder.use_var(*ret_var));
}
}
i += 1;
}
let local_callee =
self.module.declare_func_in_func(func_id, &mut self.builder.func);
let call = self.builder.ins().call(local_callee, &dsp_node_fun_params);
Ok(self.builder.inst_results(call)[0])
}
ASTNode::If(cond, then, els) => {
let condition_value = if let ASTNode::BinOp(op, a, b) = cond.as_ref() {
let val = match op {
ASTBinOp::Eq => {
let a = self.compile(a)?;
let b = self.compile(b)?;
self.builder.ins().fcmp(FloatCC::Equal, a, b)
}
ASTBinOp::Ne => {
let a = self.compile(a)?;
let b = self.compile(b)?;
let eq = self.builder.ins().fcmp(FloatCC::Equal, a, b);
self.builder.ins().bnot(eq)
}
ASTBinOp::Gt => {
let a = self.compile(a)?;
let b = self.compile(b)?;
self.builder.ins().fcmp(FloatCC::GreaterThan, a, b)
}
ASTBinOp::Lt => {
let a = self.compile(a)?;
let b = self.compile(b)?;
self.builder.ins().fcmp(FloatCC::LessThan, a, b)
}
ASTBinOp::Ge => {
let a = self.compile(a)?;
let b = self.compile(b)?;
self.builder.ins().fcmp(FloatCC::GreaterThanOrEqual, a, b)
}
ASTBinOp::Le => {
let a = self.compile(a)?;
let b = self.compile(b)?;
self.builder.ins().fcmp(FloatCC::LessThanOrEqual, a, b)
}
_ => self.compile(cond)?,
};
val
} else {
let res = self.compile(cond)?;
let cmpv = self.builder.ins().f64const(0.5);
self.builder.ins().fcmp(FloatCC::GreaterThanOrEqual, res, cmpv)
};
let then_block = self.builder.create_block();
let else_block = self.builder.create_block();
let merge_block = self.builder.create_block();
self.builder.append_block_param(merge_block, F64);
self.builder.ins().brz(condition_value, else_block, &[]);
self.builder.ins().jump(then_block, &[]);
self.builder.switch_to_block(then_block);
self.builder.seal_block(then_block);
let then_return = self.compile(then)?;
self.builder.ins().jump(merge_block, &[then_return]);
self.builder.switch_to_block(else_block);
self.builder.seal_block(else_block);
let else_return = if let Some(els) = els {
self.compile(els)?
} else {
self.builder.ins().f64const(0.0)
};
self.builder.ins().jump(merge_block, &[else_return]);
self.builder.switch_to_block(merge_block);
self.builder.seal_block(merge_block);
let phi = self.builder.block_params(merge_block)[0];
Ok(phi)
}
ASTNode::Stmts(stmts) => {
let mut value = None;
for ast in stmts {
value = Some(self.compile(ast)?);
}
if let Some(value) = value {
Ok(value)
} else {
Ok(self.builder.ins().f64const(0.0))
}
}
}
}
}
/// Defines a unit struct `$node_type` and implements `DSPNodeType` for it. Each node
/// instance's state is a heap-allocated `$struct_type`.
///
/// * `$func_name`: the native function called by JIT code for this node.
/// * `$jit_name`: the name under which the node is registered and called.
/// * `$signature`: one character per argument. `v` is an `f64` value, `D` the DSP state
///   pointer, `S` the node state pointer and `M` the multiple-return pointer. The first
///   other character, or the end of the string, ends the argument list. An `r` anywhere
///   in the string marks an `f64` return value.
///
/// `DSPNodeType`, `DSPNodeSigBit` and `DSPState` must be in scope at the call site.
/// `$struct_type` must implement `Default` and provide `reset(&mut self, &mut DSPState)`.
#[macro_export]
macro_rules! stateful_dsp_node_type {
    ($node_type: ident, $struct_type: ident =>
     $func_name: ident $jit_name: literal $signature: literal) => {
        struct $node_type;
        impl $node_type {
            fn new_ref() -> std::rc::Rc<Self> {
                std::rc::Rc::new(Self {})
            }
        }
        impl DSPNodeType for $node_type {
            fn name(&self) -> &str {
                $jit_name
            }
            fn function_ptr(&self) -> *const u8 {
                $func_name as *const u8
            }
            fn signature(&self, i: usize) -> Option<DSPNodeSigBit> {
                // Past the end of the string there are no further arguments. This used
                // to `unwrap()` and panic for signatures without a trailing terminator.
                match $signature.chars().nth(i) {
                    Some('v') => Some(DSPNodeSigBit::Value),
                    Some('D') => Some(DSPNodeSigBit::DSPStatePtr),
                    Some('S') => Some(DSPNodeSigBit::NodeStatePtr),
                    Some('M') => Some(DSPNodeSigBit::MultReturnPtr),
                    _ => None,
                }
            }
            fn has_return_value(&self) -> bool {
                $signature.contains('r')
            }
            fn reset_state(&self, dsp_state: *mut DSPState, state_ptr: *mut u8) {
                let ptr = state_ptr as *mut $struct_type;
                // SAFETY: assumes `state_ptr` came from `allocate_state` of this node type
                // and `dsp_state` is valid and unaliased. TODO confirm at call sites.
                unsafe {
                    (*ptr).reset(&mut (*dsp_state));
                }
            }
            fn allocate_state(&self) -> Option<*mut u8> {
                Some(Box::into_raw(Box::new($struct_type::default())) as *mut u8)
            }
            fn deallocate_state(&self, ptr: *mut u8) {
                // SAFETY: assumes `ptr` came from `allocate_state` of this node type and
                // is freed only once. TODO confirm at call sites.
                drop(unsafe { Box::from_raw(ptr as *mut $struct_type) });
            }
        }
    };
}
/// Compiles a DSP function whose body is the literal `0.0`.
///
/// # Panics
/// Panics if compiling this trivial function fails.
pub fn get_nop_function(
    lib: Rc<RefCell<DSPNodeTypeLibrary>>,
    dsp_ctx: Rc<RefCell<DSPNodeContext>>,
) -> Box<DSPFunction> {
    let compiler = JIT::new(lib, dsp_ctx);
    let zero_body = ASTFun::new(Box::new(ASTNode::Lit(0.0)));
    compiler.compile(zero_body).expect("No compile error")
}