use anyhow::{anyhow, Ok, Result};
use codegen::ir::{AtomicRmwOp, FuncRef, GlobalValue, StackSlot};
use cranelift::prelude::*;
use cranelift_jit::{JITBuilder, JITModule};
use cranelift_module::{DataDescription, DataId, FuncId, FuncOrDataId, Linkage, Module};
use cranelift_object::{ObjectBuilder, ObjectModule};
use std::collections::HashMap;
use std::iter::zip;
use std::sync::{Mutex, MutexGuard};
use target_lexicon::{Endianness, PointerWidth, Triple};
use crate::ast::{Ast, AstKind};
use crate::discretise::{DiscreteModel, Tensor, TensorBlock};
use crate::execution::compiler::CompilerOptions;
use crate::execution::module::{
CodegenModule, CodegenModuleCompile, CodegenModuleEmit, CodegenModuleJit,
};
use crate::execution::scalar::RealType;
use crate::execution::{DataLayout, Translation, TranslationFrom, TranslationTo};
/// Cranelift-backed codegen module, generic over the concrete backend:
/// `JITModule` for in-process execution, `ObjectModule` for emitting a
/// relocatable object file.
pub struct CraneliftModule<M: Module> {
    // Reused across function compilations to amortize allocations.
    builder_context: FunctionBuilderContext,
    // Per-function compilation context; cleared after each definition.
    ctx: codegen::Context,
    // Backend module, behind a Mutex so codegen helpers can lock it
    // independently of the rest of `self`.
    module: Mutex<M>,
    layout: DataLayout,
    // Ids of module-level data defined in `new`:
    indices_id: DataId,     // read-only serialized tensor index table
    constants_id: DataId,   // writable zero-initialized constants buffer
    model_index_id: DataId, // writable slot for the selected model index
    // Shared counter backing the spin barrier; `Some` only when `threaded`.
    thread_counter: Option<DataId>,
    int_type: types::Type,
    real_type: types::Type,
    real_ptr_type: types::Type,
    int_ptr_type: types::Type,
    threaded: bool,
}
/// In-process JIT variant.
pub type CraneliftJitModule = CraneliftModule<JITModule>;
/// Object-file-emitting variant.
pub type CraneliftObjectModule = CraneliftModule<ObjectModule>;
impl<M: Module> CraneliftModule<M> {
/// Declare the function currently built up in `self.ctx` under `name`
/// with `Export` linkage, define it in the backend module, and reset the
/// context so the next function can be compiled.
fn declare_function(&mut self, name: &str) -> Result<FuncId> {
    let mut backend = self.module.lock().unwrap();
    let func_id = backend.declare_function(name, Linkage::Export, &self.ctx.func.signature)?;
    backend.define_function(func_id, &mut self.ctx)?;
    backend.clear_context(&mut self.ctx);
    Ok(func_id)
}
/// Compile `barrier_init()`: resets the shared thread counter to zero.
/// Intended to run once before threaded execution so subsequent `barrier`
/// calls count up from a known baseline.
fn compile_barrier_init(&mut self) -> Result<FuncId> {
    self.ctx.func.signature.params.clear();
    self.ctx.func.signature.returns.clear();
    let mut builder = FunctionBuilder::new(&mut self.ctx.func, &mut self.builder_context);
    let mut module = self.module.lock().unwrap();
    let thread_counter =
        module.declare_data_in_func(self.thread_counter.unwrap(), builder.func);
    let entry_block = builder.create_block();
    builder.switch_to_block(entry_block);
    builder.seal_block(entry_block);
    let thread_counter = builder
        .ins()
        .global_value(self.int_ptr_type, thread_counter);
    let zero = builder.ins().iconst(self.int_type, 0);
    // Plain (non-atomic) store: presumably init runs before any worker
    // threads start spinning — TODO confirm the caller guarantees this.
    builder
        .ins()
        .store(MemFlags::new(), zero, thread_counter, 0);
    builder.ins().return_(&[]);
    builder.finalize();
    // Declared/defined inline rather than via `self.declare_function`
    // because the module lock is already held here; re-locking the
    // non-reentrant Mutex would deadlock.
    let name = "barrier_init";
    let id = module.declare_function(name, Linkage::Export, &self.ctx.func.signature)?;
    module.define_function(id, &mut self.ctx)?;
    module.clear_context(&mut self.ctx);
    Ok(id)
}
/// Compile `barrier(thread_count)`: a spin barrier over the shared
/// `thread_counter` global. Each caller atomically adds 1, then busy-waits
/// until the counter reaches `thread_count`. The counter is never reset
/// here; callers pass a cumulative threshold per barrier instance (see
/// `jit_compile_call_barrier`, which passes `threadDim * (nbarrier + 1)`).
fn compile_barrier(&mut self) -> Result<FuncId> {
    self.ctx.func.signature.params.clear();
    self.ctx.func.signature.returns.clear();
    let arg_types = &[self.int_type];
    for ty in arg_types {
        self.ctx.func.signature.params.push(AbiParam::new(*ty));
    }
    let mut builder = FunctionBuilder::new(&mut self.ctx.func, &mut self.builder_context);
    let mut module = self.module.lock().unwrap();
    let thread_counter =
        module.declare_data_in_func(self.thread_counter.unwrap(), builder.func);
    let entry_block = builder.create_block();
    let wait_loop_block = builder.create_block();
    let barrier_done_block = builder.create_block();
    builder.append_block_params_for_function_params(entry_block);
    let thread_count = builder.block_params(entry_block)[0];
    builder.switch_to_block(entry_block);
    builder.seal_block(entry_block);
    let thread_counter = builder
        .ins()
        .global_value(self.int_ptr_type, thread_counter);
    let one = builder.ins().iconst(self.int_type, 1);
    // Atomically announce this thread's arrival.
    builder.ins().atomic_rmw(
        self.int_type,
        MemFlags::new(),
        AtomicRmwOp::Add,
        thread_counter,
        one,
    );
    builder.ins().jump(wait_loop_block, &[]);
    builder.switch_to_block(wait_loop_block);
    // Spin: reload the counter until every thread has arrived.
    let current_value =
        builder
            .ins()
            .atomic_load(self.int_type, MemFlags::new(), thread_counter);
    let all_threads_done = builder.ins().icmp(
        IntCC::UnsignedGreaterThanOrEqual,
        current_value,
        thread_count,
    );
    builder.ins().brif(
        all_threads_done,
        barrier_done_block,
        &[],
        wait_loop_block,
        &[],
    );
    builder.seal_block(wait_loop_block);
    builder.switch_to_block(barrier_done_block);
    builder.seal_block(barrier_done_block);
    builder.ins().return_(&[]);
    builder.finalize();
    // Declared inline: the module lock is already held (see
    // `compile_barrier_init` for why `self.declare_function` can't be used).
    let name = "barrier";
    let id = module.declare_function(name, Linkage::Export, &self.ctx.func.signature)?;
    module.define_function(id, &mut self.ctx)?;
    module.clear_context(&mut self.ctx);
    Ok(id)
}
/// Compile `calc_out_grad(t, u, du, data, ddata, out, dout, threadId,
/// threadDim)`: forward-mode tangent of `calc_out`. The `d*` buffers
/// carry the tangent values.
fn compile_calc_out_grad(
    &mut self,
    _func_id: &FuncId,
    model: &DiscreteModel,
) -> Result<FuncId> {
    let arg_types = &[
        self.real_type,
        self.real_ptr_type,
        self.real_ptr_type,
        self.real_ptr_type,
        self.real_ptr_type,
        self.real_ptr_type,
        self.real_ptr_type,
        self.int_type,
        self.int_type,
    ];
    let arg_names = &[
        "t",
        "u",
        "du",
        "data",
        "ddata",
        "out",
        "dout",
        "threadId",
        "threadDim",
    ];
    {
        // Scope ends the codegen borrow of `self` before `declare_function`.
        let mut codegen = CraneliftCodeGen::new(self, model, arg_names, arg_types);
        if let Some(out) = model.out() {
            // NOTE(review): unlike `compile_calc_out`, no time-/state-dep
            // definitions are recompiled here — presumably their tangents
            // are already present in `ddata`; confirm with callers.
            // `true` selects the tangent form of the tensor.
            codegen.jit_compile_tensor(out, None, true)?;
        }
        codegen.builder.ins().return_(&[]);
        codegen.builder.finalize();
    }
    self.declare_function("calc_out_grad")
}
/// Compile `rhs_grad(t, u, du, data, ddata, rr, drr, threadId, threadDim)`:
/// forward-mode tangent of `rhs`. Recomputes tangents of time- and
/// state-dependent definitions, then writes the rhs tangent into `drr`.
fn compile_rhs_grad(&mut self, _func_id: &FuncId, model: &DiscreteModel) -> Result<FuncId> {
    let arg_types = &[
        self.real_type,
        self.real_ptr_type,
        self.real_ptr_type,
        self.real_ptr_type,
        self.real_ptr_type,
        self.real_ptr_type,
        self.real_ptr_type,
        self.int_type,
        self.int_type,
    ];
    let arg_names = &[
        "t",
        "u",
        "du",
        "data",
        "ddata",
        "rr",
        "drr",
        "threadId",
        "threadDim",
    ];
    {
        let mut codegen = CraneliftCodeGen::new(self, model, arg_names, arg_types);
        // A barrier after each definition keeps threaded workers in lockstep.
        let mut nbarrier = 0;
        for tensor in model.time_dep_defns() {
            codegen.jit_compile_tensor(tensor, None, true)?;
            codegen.jit_compile_call_barrier(nbarrier);
            nbarrier += 1;
        }
        for a in model.state_dep_defns() {
            codegen.jit_compile_tensor(a, None, true)?;
            codegen.jit_compile_call_barrier(nbarrier);
            nbarrier += 1;
        }
        let res = *codegen.variables.get("drr").unwrap();
        codegen.jit_compile_tensor(model.rhs(), Some(res), true)?;
        codegen.builder.ins().return_(&[]);
        codegen.builder.finalize();
    }
    self.declare_function("rhs_grad")
}
/// Compile `reset_grad(...)`: forward-mode tangent of `reset`. Recomputes
/// dependent-definition tangents, then writes the reset tangent into
/// `dreset`. A no-op body (just `return`) when the model has no reset.
fn compile_reset_grad(&mut self, _func_id: &FuncId, model: &DiscreteModel) -> Result<FuncId> {
    let arg_types = &[
        self.real_type,
        self.real_ptr_type,
        self.real_ptr_type,
        self.real_ptr_type,
        self.real_ptr_type,
        self.real_ptr_type,
        self.real_ptr_type,
        self.int_type,
        self.int_type,
    ];
    let arg_names = &[
        "t",
        "u",
        "du",
        "data",
        "ddata",
        "reset",
        "dreset",
        "threadId",
        "threadDim",
    ];
    {
        let mut codegen = CraneliftCodeGen::new(self, model, arg_names, arg_types);
        if let Some(reset) = model.reset() {
            // Barrier after each definition keeps threaded workers in sync.
            let mut nbarrier = 0;
            for tensor in model.time_dep_defns() {
                codegen.jit_compile_tensor(tensor, None, true)?;
                codegen.jit_compile_call_barrier(nbarrier);
                nbarrier += 1;
            }
            for tensor in model.state_dep_defns() {
                codegen.jit_compile_tensor(tensor, None, true)?;
                codegen.jit_compile_call_barrier(nbarrier);
                nbarrier += 1;
            }
            for tensor in model.state_dep_post_f_defns() {
                codegen.jit_compile_tensor(tensor, None, true)?;
                codegen.jit_compile_call_barrier(nbarrier);
                nbarrier += 1;
            }
            let dreset_ptr = *codegen.variables.get("dreset").unwrap();
            codegen.jit_compile_tensor(reset, Some(dreset_ptr), true)?;
        }
        codegen.builder.ins().return_(&[]);
        codegen.builder.finalize();
    }
    self.declare_function("reset_grad")
}
/// Compile `calc_stop_grad(...)`: forward-mode tangent of `calc_stop`.
/// Writes the stop-condition tangent into `droot`; a no-op body when the
/// model defines no stop condition.
fn compile_calc_stop_grad(
    &mut self,
    _func_id: &FuncId,
    model: &DiscreteModel,
) -> Result<FuncId> {
    let arg_types = &[
        self.real_type,
        self.real_ptr_type,
        self.real_ptr_type,
        self.real_ptr_type,
        self.real_ptr_type,
        self.real_ptr_type,
        self.real_ptr_type,
        self.int_type,
        self.int_type,
    ];
    let arg_names = &[
        "t",
        "u",
        "du",
        "data",
        "ddata",
        "root",
        "droot",
        "threadId",
        "threadDim",
    ];
    {
        let mut codegen = CraneliftCodeGen::new(self, model, arg_names, arg_types);
        if let Some(stop) = model.stop() {
            // Barrier after each definition keeps threaded workers in sync.
            let mut nbarrier = 0;
            for tensor in model.time_dep_defns() {
                codegen.jit_compile_tensor(tensor, None, true)?;
                codegen.jit_compile_call_barrier(nbarrier);
                nbarrier += 1;
            }
            for tensor in model.state_dep_defns() {
                codegen.jit_compile_tensor(tensor, None, true)?;
                codegen.jit_compile_call_barrier(nbarrier);
                nbarrier += 1;
            }
            for tensor in model.state_dep_post_f_defns() {
                codegen.jit_compile_tensor(tensor, None, true)?;
                codegen.jit_compile_call_barrier(nbarrier);
                nbarrier += 1;
            }
            let droot_ptr = *codegen.variables.get("droot").unwrap();
            codegen.jit_compile_tensor(stop, Some(droot_ptr), true)?;
        }
        codegen.builder.ins().return_(&[]);
        codegen.builder.finalize();
    }
    self.declare_function("calc_stop_grad")
}
/// Compile `set_inputs_grad(inputs, dinputs, data, ddata, model_index)`:
/// tangent of `set_inputs` — records `model_index` in the module global
/// and scatters the input tangents into `ddata`.
fn compile_set_inputs_grad(
    &mut self,
    _func_id: &FuncId,
    model: &DiscreteModel,
) -> Result<FuncId> {
    let arg_types = &[
        self.real_ptr_type,
        self.real_ptr_type,
        self.real_ptr_type,
        self.real_ptr_type,
        self.int_type,
    ];
    let arg_names = &["inputs", "dinputs", "data", "ddata", "model_index"];
    {
        let mut codegen = CraneliftCodeGen::new(self, model, arg_names, arg_types);
        // Persist the selected model index into the module-level global.
        let model_index_ptr = codegen
            .builder
            .ins()
            .global_value(codegen.int_ptr_type, codegen.model_index_global);
        let model_index = codegen
            .builder
            .use_var(*codegen.variables.get("model_index").unwrap());
        codegen
            .builder
            .ins()
            .store(codegen.mem_flags, model_index, model_index_ptr, 0);
        // Tangent form (`true`) targeting the tangent data buffer `ddata`.
        let base_data_ptr = codegen.variables.get("ddata").unwrap();
        let base_data_ptr = codegen.builder.use_var(*base_data_ptr);
        codegen.jit_compile_inputs(model, base_data_ptr, true, false);
        codegen.builder.ins().return_(&[]);
        codegen.builder.finalize();
    }
    self.declare_function("set_inputs_grad")
}
/// Compile `set_constants(threadId, threadDim)`: evaluates every constant
/// tensor definition, with a barrier after each one so threaded workers
/// stay in lockstep.
fn compile_set_constants(&mut self, model: &DiscreteModel) -> Result<FuncId> {
    let arg_types = &[self.int_type, self.int_type];
    let arg_names = &["threadId", "threadDim"];
    {
        let mut codegen = CraneliftCodeGen::new(self, model, arg_names, arg_types);
        for (nbarrier, tensor) in model.constant_defns().into_iter().enumerate() {
            // Primal form (`false`): no tangent computation for constants.
            codegen.jit_compile_tensor(tensor, None, false)?;
            codegen.jit_compile_call_barrier(nbarrier as i64);
        }
        codegen.builder.ins().return_(&[]);
        codegen.builder.finalize();
    }
    self.declare_function("set_constants")
}
/// Compile `set_u0_grad(u0, du0, data, ddata, threadId, threadDim)`:
/// tangent of `set_u0` — recomputes input-dependent definition tangents,
/// then writes the initial-state tangent into `du0`.
fn compile_set_u0_grad(&mut self, _func_id: &FuncId, model: &DiscreteModel) -> Result<FuncId> {
    let arg_types = &[
        self.real_ptr_type,
        self.real_ptr_type,
        self.real_ptr_type,
        self.real_ptr_type,
        self.int_type,
        self.int_type,
    ];
    let arg_names = &["u0", "du0", "data", "ddata", "threadId", "threadDim"];
    {
        let mut codegen = CraneliftCodeGen::new(self, model, arg_names, arg_types);
        // Barrier after each definition keeps threaded workers in sync.
        let mut nbarrier = 0;
        #[allow(clippy::explicit_counter_loop)]
        for a in model.input_dep_defns() {
            codegen.jit_compile_tensor(a, None, true)?;
            codegen.jit_compile_call_barrier(nbarrier);
            nbarrier += 1;
        }
        codegen.jit_compile_tensor(
            model.state(),
            Some(*codegen.variables.get("du0").unwrap()),
            true,
        )?;
        codegen.builder.ins().return_(&[]);
        codegen.builder.finalize();
    }
    self.declare_function("set_u0_grad")
}
/// Construct the module: define module-level data (constants buffer,
/// serialized index table, model-index slot, and — when threaded — the
/// barrier counter), then compile every exported function including the
/// forward-mode gradient variants.
fn new(
    triple: Triple,
    model: &DiscreteModel,
    threaded: bool,
    mut module: M,
    real_type: RealType,
) -> Result<Self> {
    // Pointer width of the target determines both real and int pointer types.
    let ptr_type = match triple.pointer_width().unwrap() {
        PointerWidth::U16 => types::I16,
        PointerWidth::U32 => types::I32,
        PointerWidth::U64 => types::I64,
    };
    let layout = DataLayout::new(model);
    let int_type = types::I32;
    let real_type_cranelift = match real_type {
        RealType::F32 => types::F32,
        RealType::F64 => types::F64,
    };
    // Writable, zero-initialized buffer for constant tensors (filled at
    // runtime by the generated `set_constants` function).
    let mut data_description = DataDescription::new();
    data_description
        .define_zeroinit(layout.constants().len() * (real_type_cranelift.bytes() as usize));
    let constants_id = module.declare_data("constants", Linkage::Local, true, false)?;
    module.define_data(constants_id, &data_description)?;
    // Serialize the tensor index table with the target's endianness.
    let mut vec8: Vec<u8> = vec![];
    for elem in layout.indices() {
        if int_type == types::I64 {
            // Currently dead: `int_type` is fixed to I32 above. Kept for a
            // future switch to 64-bit indices.
            let elemi64 = i64::from(*elem);
            let conv = match triple.endianness().unwrap() {
                Endianness::Little => elemi64.to_le_bytes(),
                Endianness::Big => elemi64.to_be_bytes(),
            };
            vec8.extend(conv.into_iter());
        } else {
            let conv = match triple.endianness().unwrap() {
                Endianness::Little => elem.to_le_bytes(),
                Endianness::Big => elem.to_be_bytes(),
            };
            vec8.extend(conv.into_iter());
        };
    }
    let mut data_description = DataDescription::new();
    data_description.define(vec8.into_boxed_slice());
    let indices_id = module.declare_data("indices", Linkage::Local, false, false)?;
    module.define_data(indices_id, &data_description)?;
    // Writable slot holding the currently selected model index.
    let mut data_description = DataDescription::new();
    data_description.define_zeroinit(int_type.bytes().try_into().unwrap());
    let model_index_id = module.declare_data("model_index", Linkage::Local, true, false)?;
    module.define_data(model_index_id, &data_description)?;
    // Shared barrier counter, only needed for threaded execution.
    let mut thread_counter = None;
    if threaded {
        let mut data_description = DataDescription::new();
        data_description.define_zeroinit(int_type.bytes().try_into().unwrap());
        let the_thread_counter =
            module.declare_data("thread_counter", Linkage::Local, true, false)?;
        module.define_data(the_thread_counter, &data_description)?;
        thread_counter = Some(the_thread_counter);
    }
    let mut ret = Self {
        builder_context: FunctionBuilderContext::new(),
        ctx: module.make_context(),
        module: Mutex::new(module),
        indices_id,
        constants_id,
        model_index_id,
        int_type,
        real_type: real_type_cranelift,
        real_ptr_type: ptr_type,
        int_ptr_type: ptr_type,
        layout,
        threaded,
        thread_counter,
    };
    if threaded {
        ret.compile_barrier_init()?;
        ret.compile_barrier()?;
    }
    // Primal functions first; their ids are passed to the grad variants.
    let set_u0 = ret.compile_set_u0(model)?;
    let _calc_stop = ret.compile_calc_stop(model)?;
    let reset = ret.compile_reset(model)?;
    let rhs = ret.compile_rhs(model)?;
    let _mass = ret.compile_mass(model)?;
    let calc_out = ret.compile_calc_out(model)?;
    let _set_id = ret.compile_set_id(model)?;
    let _get_dims = ret.compile_get_dims(model)?;
    let set_inputs = ret.compile_set_inputs(model)?;
    let _get_inputs = ret.compile_get_inputs(model)?;
    let _set_constants = ret.compile_set_constants(model)?;
    // Accessor functions for every tensor in the layout. Names are
    // collected first to end the borrow of `ret.layout`.
    let tensor_info = ret
        .layout
        .tensors()
        .map(|(name, is_constant)| (name.to_string(), is_constant))
        .collect::<Vec<_>>();
    for (tensor, is_constant) in tensor_info {
        if is_constant {
            ret.compile_get_constant(model, tensor.as_str())?;
        } else {
            ret.compile_get_tensor(model, tensor.as_str())?;
        }
    }
    // Forward-mode gradient variants.
    let _set_u0_grad = ret.compile_set_u0_grad(&set_u0, model)?;
    let _rhs_grad = ret.compile_rhs_grad(&rhs, model)?;
    let _reset_grad = ret.compile_reset_grad(&reset, model)?;
    let _calc_stop_grad = ret.compile_calc_stop_grad(&_calc_stop, model)?;
    let _calc_out_grad = ret.compile_calc_out_grad(&calc_out, model)?;
    let _set_inputs_grad = ret.compile_set_inputs_grad(&set_inputs, model)?;
    Ok(ret)
}
/// Compile `set_u0(u0, data, threadId, threadDim)`: evaluates
/// input-dependent definitions, then writes the initial state into `u0`.
fn compile_set_u0(&mut self, model: &DiscreteModel) -> Result<FuncId> {
    let arg_types = &[
        self.real_ptr_type,
        self.real_ptr_type,
        self.int_type,
        self.int_type,
    ];
    let arg_names = &["u0", "data", "threadId", "threadDim"];
    {
        let mut codegen = CraneliftCodeGen::new(self, model, arg_names, arg_types);
        // Barrier after each definition keeps threaded workers in sync.
        let mut nbarrier = 0;
        #[allow(clippy::explicit_counter_loop)]
        for a in model.input_dep_defns() {
            codegen.jit_compile_tensor(a, None, false)?;
            codegen.jit_compile_call_barrier(nbarrier);
            nbarrier += 1;
        }
        codegen.jit_compile_tensor(
            model.state(),
            Some(*codegen.variables.get("u0").unwrap()),
            false,
        )?;
        codegen.builder.ins().return_(&[]);
        codegen.builder.finalize();
    }
    self.declare_function("set_u0")
}
/// Compile `calc_out(t, u, data, out, threadId, threadDim)`: recomputes
/// time-/state-dependent definitions, then evaluates the output tensor.
/// A no-op body when the model has no output.
fn compile_calc_out(&mut self, model: &DiscreteModel) -> Result<FuncId> {
    let arg_types = &[
        self.real_type,
        self.real_ptr_type,
        self.real_ptr_type,
        self.real_ptr_type,
        self.int_type,
        self.int_type,
    ];
    let arg_names = &["t", "u", "data", "out", "threadId", "threadDim"];
    {
        let mut codegen = CraneliftCodeGen::new(self, model, arg_names, arg_types);
        if let Some(out) = model.out() {
            // Barrier after each definition keeps threaded workers in sync.
            let mut nbarrier = 0;
            for tensor in model.time_dep_defns() {
                codegen.jit_compile_tensor(tensor, None, false)?;
                codegen.jit_compile_call_barrier(nbarrier);
                nbarrier += 1;
            }
            for a in model.state_dep_defns() {
                codegen.jit_compile_tensor(a, None, false)?;
                codegen.jit_compile_call_barrier(nbarrier);
                nbarrier += 1;
            }
            for a in model.state_dep_post_f_defns() {
                codegen.jit_compile_tensor(a, None, false)?;
                codegen.jit_compile_call_barrier(nbarrier);
                nbarrier += 1;
            }
            // No explicit destination: `out` is written via its own layout.
            codegen.jit_compile_tensor(out, None, false)?;
        }
        codegen.builder.ins().return_(&[]);
        codegen.builder.finalize();
    }
    self.declare_function("calc_out")
}
/// Compile `calc_stop(t, u, data, root, threadId, threadDim)`: evaluates
/// the stop (root-finding) tensor into `root`. A no-op body when the
/// model defines no stop condition.
fn compile_calc_stop(&mut self, model: &DiscreteModel) -> Result<FuncId> {
    let arg_types = &[
        self.real_type,
        self.real_ptr_type,
        self.real_ptr_type,
        self.real_ptr_type,
        self.int_type,
        self.int_type,
    ];
    let arg_names = &["t", "u", "data", "root", "threadId", "threadDim"];
    {
        let mut codegen = CraneliftCodeGen::new(self, model, arg_names, arg_types);
        if let Some(stop) = model.stop() {
            // Barrier after each definition keeps threaded workers in sync.
            let mut nbarrier = 0;
            for tensor in model.time_dep_defns() {
                codegen.jit_compile_tensor(tensor, None, false)?;
                codegen.jit_compile_call_barrier(nbarrier);
                nbarrier += 1;
            }
            for a in model.state_dep_defns() {
                codegen.jit_compile_tensor(a, None, false)?;
                codegen.jit_compile_call_barrier(nbarrier);
                nbarrier += 1;
            }
            for a in model.state_dep_post_f_defns() {
                codegen.jit_compile_tensor(a, None, false)?;
                codegen.jit_compile_call_barrier(nbarrier);
                nbarrier += 1;
            }
            let root = *codegen.variables.get("root").unwrap();
            codegen.jit_compile_tensor(stop, Some(root), false)?;
        }
        codegen.builder.ins().return_(&[]);
        codegen.builder.finalize();
    }
    self.declare_function("calc_stop")
}
/// Compile `reset(t, u, data, reset, threadId, threadDim)`: evaluates the
/// reset tensor into `reset`. A no-op body when the model has no reset.
fn compile_reset(&mut self, model: &DiscreteModel) -> Result<FuncId> {
    let arg_types = &[
        self.real_type,
        self.real_ptr_type,
        self.real_ptr_type,
        self.real_ptr_type,
        self.int_type,
        self.int_type,
    ];
    let arg_names = &["t", "u", "data", "reset", "threadId", "threadDim"];
    {
        let mut codegen = CraneliftCodeGen::new(self, model, arg_names, arg_types);
        if let Some(reset) = model.reset() {
            // Barrier after each definition keeps threaded workers in sync.
            let mut nbarrier = 0;
            for tensor in model.time_dep_defns() {
                codegen.jit_compile_tensor(tensor, None, false)?;
                codegen.jit_compile_call_barrier(nbarrier);
                nbarrier += 1;
            }
            for tensor in model.state_dep_defns() {
                codegen.jit_compile_tensor(tensor, None, false)?;
                codegen.jit_compile_call_barrier(nbarrier);
                nbarrier += 1;
            }
            for tensor in model.state_dep_post_f_defns() {
                codegen.jit_compile_tensor(tensor, None, false)?;
                codegen.jit_compile_call_barrier(nbarrier);
                nbarrier += 1;
            }
            let reset_ptr = *codegen.variables.get("reset").unwrap();
            codegen.jit_compile_tensor(reset, Some(reset_ptr), false)?;
        }
        codegen.builder.ins().return_(&[]);
        codegen.builder.finalize();
    }
    self.declare_function("reset")
}
/// Compile `rhs(t, u, data, rr, threadId, threadDim)`: recomputes time-
/// and state-dependent definitions, then evaluates the ODE right-hand
/// side into `rr`.
fn compile_rhs(&mut self, model: &DiscreteModel) -> Result<FuncId> {
    let arg_types = &[
        self.real_type,
        self.real_ptr_type,
        self.real_ptr_type,
        self.real_ptr_type,
        self.int_type,
        self.int_type,
    ];
    let arg_names = &["t", "u", "data", "rr", "threadId", "threadDim"];
    {
        let mut codegen = CraneliftCodeGen::new(self, model, arg_names, arg_types);
        // Barrier after each definition keeps threaded workers in sync.
        let mut nbarrier = 0;
        for tensor in model.time_dep_defns() {
            codegen.jit_compile_tensor(tensor, None, false)?;
            codegen.jit_compile_call_barrier(nbarrier);
            nbarrier += 1;
        }
        for a in model.state_dep_defns() {
            codegen.jit_compile_tensor(a, None, false)?;
            codegen.jit_compile_call_barrier(nbarrier);
            nbarrier += 1;
        }
        let res = *codegen.variables.get("rr").unwrap();
        codegen.jit_compile_tensor(model.rhs(), Some(res), false)?;
        codegen.builder.ins().return_(&[]);
        codegen.builder.finalize();
    }
    self.declare_function("rhs")
}
/// Compile `mass(t, dudt, data, rr, threadId, threadDim)`: evaluates the
/// mass-matrix action (the model's lhs) into `rr`. Emits a no-op body
/// when the model has no `state_dot`/`lhs` (identity mass).
fn compile_mass(&mut self, model: &DiscreteModel) -> Result<FuncId> {
    let arg_types = &[
        self.real_type,
        self.real_ptr_type,
        self.real_ptr_type,
        self.real_ptr_type,
        self.int_type,
        self.int_type,
    ];
    let arg_names = &["t", "dudt", "data", "rr", "threadId", "threadDim"];
    {
        let mut codegen = CraneliftCodeGen::new(self, model, arg_names, arg_types);
        if model.state_dot().is_some() && model.lhs().is_some() {
            // Barrier after each definition keeps threaded workers in sync.
            let mut nbarrier = 0;
            for tensor in model.time_dep_defns() {
                codegen.jit_compile_tensor(tensor, None, false)?;
                codegen.jit_compile_call_barrier(nbarrier);
                nbarrier += 1;
            }
            for a in model.dstate_dep_defns() {
                codegen.jit_compile_tensor(a, None, false)?;
                codegen.jit_compile_call_barrier(nbarrier);
                nbarrier += 1;
            }
            let lhs = model.lhs().unwrap();
            let res = codegen.variables.get("rr").unwrap();
            codegen.jit_compile_tensor(lhs, Some(*res), false)?;
        }
        codegen.builder.ins().return_(&[]);
        codegen.builder.finalize();
    }
    self.declare_function("mass")
}
/// Compile `get_dims(states, inputs, outputs, data, stop, has_mass,
/// has_reset)`: stores each model dimension (or boolean flag as 0/1)
/// through the corresponding out-pointer argument.
fn compile_get_dims(&mut self, model: &DiscreteModel) -> Result<FuncId> {
    let arg_types = &[
        self.int_ptr_type,
        self.int_ptr_type,
        self.int_ptr_type,
        self.int_ptr_type,
        self.int_ptr_type,
        self.int_ptr_type,
        self.int_ptr_type,
    ];
    let arg_names = &[
        "states",
        "inputs",
        "outputs",
        "data",
        "stop",
        "has_mass",
        "has_reset",
    ];
    {
        let mut codegen = CraneliftCodeGen::new(self, model, arg_names, arg_types);
        // Dimensions are nonzero counts; missing optional tensors count as 0.
        let number_of_states = i64::try_from(model.state().nnz()).unwrap();
        let number_of_inputs =
            i64::try_from(model.input().map(|inp| inp.nnz()).unwrap_or(0)).unwrap();
        let number_of_outputs =
            i64::try_from(model.out().map(|out| out.nnz()).unwrap_or(0)).unwrap();
        let number_of_stop =
            i64::try_from(model.stop().map(|stop| stop.nnz()).unwrap_or(0)).unwrap();
        // Boolean flags are reported as 0/1 integers.
        let has_mass = i64::from(model.lhs().is_some());
        let has_reset = i64::from(model.reset().is_some());
        let data_len = i64::try_from(codegen.layout.data().len()).unwrap();
        // Store each constant through its matching out-pointer argument.
        for (val, name) in [
            (number_of_states, "states"),
            (number_of_inputs, "inputs"),
            (number_of_outputs, "outputs"),
            (data_len, "data"),
            (number_of_stop, "stop"),
            (has_mass, "has_mass"),
            (has_reset, "has_reset"),
        ] {
            let val = codegen.builder.ins().iconst(codegen.int_type, val);
            let ptr = codegen.variables.get(name).unwrap();
            let ptr = codegen.builder.use_var(*ptr);
            codegen.builder.ins().store(codegen.mem_flags, val, ptr, 0);
        }
        codegen.builder.ins().return_(&[]);
        codegen.builder.finalize();
    }
    self.declare_function("get_dims")
}
/// Compile `get_tensor_<name>(data, tensor_data, tensor_size)`: stores
/// the tensor's pointer and its nonzero count through the out-pointers.
fn compile_get_tensor(&mut self, model: &DiscreteModel, name: &str) -> Result<FuncId> {
    let arg_types = &[self.real_ptr_type, self.real_ptr_type, self.int_ptr_type];
    let arg_names = &["data", "tensor_data", "tensor_size"];
    {
        let mut codegen = CraneliftCodeGen::new(self, model, arg_names, arg_types);
        // `name` resolves to a variable registered by `CraneliftCodeGen::new`
        // — presumably a pointer into `data`; confirm in its constructor.
        let tensor_ptr = codegen.variables.get(name).unwrap();
        let tensor_ptr = codegen.builder.use_var(*tensor_ptr);
        let tensor_size =
            i64::try_from(codegen.layout.get_layout(name).unwrap().nnz()).unwrap();
        let tensor_size = codegen.builder.ins().iconst(codegen.int_type, tensor_size);
        for (val, name) in [(tensor_ptr, "tensor_data"), (tensor_size, "tensor_size")] {
            let ptr = codegen.variables.get(name).unwrap();
            let ptr = codegen.builder.use_var(*ptr);
            codegen.builder.ins().store(codegen.mem_flags, val, ptr, 0);
        }
        codegen.builder.ins().return_(&[]);
        codegen.builder.finalize();
    }
    self.declare_function(format!("get_tensor_{name}").as_str())
}
/// Compile `get_constant_<name>(tensor_data, tensor_size)`: like
/// `compile_get_tensor` but without a `data` argument — constant tensors
/// live in the module-level constants buffer, not in the data array.
fn compile_get_constant(&mut self, model: &DiscreteModel, name: &str) -> Result<FuncId> {
    let arg_types = &[self.real_ptr_type, self.int_ptr_type];
    let arg_names = &["tensor_data", "tensor_size"];
    {
        let mut codegen = CraneliftCodeGen::new(self, model, arg_names, arg_types);
        // `name` resolves to a variable registered by `CraneliftCodeGen::new`.
        let tensor_ptr = codegen.variables.get(name).unwrap();
        let tensor_ptr = codegen.builder.use_var(*tensor_ptr);
        let tensor_size =
            i64::try_from(codegen.layout.get_layout(name).unwrap().nnz()).unwrap();
        let tensor_size = codegen.builder.ins().iconst(codegen.int_type, tensor_size);
        for (val, name) in [(tensor_ptr, "tensor_data"), (tensor_size, "tensor_size")] {
            let ptr = codegen.variables.get(name).unwrap();
            let ptr = codegen.builder.use_var(*ptr);
            codegen.builder.ins().store(codegen.mem_flags, val, ptr, 0);
        }
        codegen.builder.ins().return_(&[]);
        codegen.builder.finalize();
    }
    self.declare_function(format!("get_constant_{name}").as_str())
}
/// Compile `set_inputs(inputs, data, model_index)`: records `model_index`
/// in the module global and scatters the inputs into `data`.
fn compile_set_inputs(&mut self, model: &DiscreteModel) -> Result<FuncId> {
    let arg_types = &[self.real_ptr_type, self.real_ptr_type, self.int_type];
    let arg_names = &["inputs", "data", "model_index"];
    {
        let mut codegen = CraneliftCodeGen::new(self, model, arg_names, arg_types);
        // Persist the selected model index into the module-level global.
        let model_index_ptr = codegen
            .builder
            .ins()
            .global_value(codegen.int_ptr_type, codegen.model_index_global);
        let model_index = codegen
            .builder
            .use_var(*codegen.variables.get("model_index").unwrap());
        codegen
            .builder
            .ins()
            .store(codegen.mem_flags, model_index, model_index_ptr, 0);
        let base_data_ptr = codegen.variables.get("data").unwrap();
        let base_data_ptr = codegen.builder.use_var(*base_data_ptr);
        codegen.jit_compile_inputs(model, base_data_ptr, false, false);
        codegen.builder.ins().return_(&[]);
        codegen.builder.finalize();
    }
    self.declare_function("set_inputs")
}
/// Compile `get_inputs(inputs, data)`: inverse of `set_inputs`. The final
/// `true` flag presumably reverses the copy direction (data -> inputs) —
/// confirm against `jit_compile_inputs`.
fn compile_get_inputs(&mut self, model: &DiscreteModel) -> Result<FuncId> {
    let arg_types = &[self.real_ptr_type, self.real_ptr_type];
    let arg_names = &["inputs", "data"];
    {
        let mut codegen = CraneliftCodeGen::new(self, model, arg_names, arg_types);
        let base_data_ptr = codegen.variables.get("data").unwrap();
        let base_data_ptr = codegen.builder.use_var(*base_data_ptr);
        codegen.jit_compile_inputs(model, base_data_ptr, false, true);
        codegen.builder.ins().return_(&[]);
        codegen.builder.finalize();
    }
    self.declare_function("get_inputs")
}
/// Compile `set_id(id)`: fills `id` with 1.0 for differential states and
/// 0.0 for algebraic states, emitting one loop per state block.
fn compile_set_id(&mut self, model: &DiscreteModel) -> Result<FuncId> {
    let arg_types = &[self.real_ptr_type];
    let arg_names = &["id"];
    {
        let mut codegen = CraneliftCodeGen::new(self, model, arg_names, arg_types);
        // Running offset of the current block within the flat `id` array.
        let mut id_index = 0usize;
        for (blk, is_algebraic) in zip(model.state().elmts(), model.is_algebraic()) {
            let id_start_index = codegen
                .builder
                .ins()
                .iconst(codegen.int_type, i64::try_from(id_index).unwrap());
            let blk_start_index = codegen.builder.ins().iconst(codegen.int_type, 0);
            // Loop block takes the within-block index as a block parameter.
            let blk_block = codegen.builder.create_block();
            let curr_blk_index = codegen
                .builder
                .append_block_param(blk_block, codegen.int_type);
            codegen
                .builder
                .ins()
                .jump(blk_block, &[blk_start_index.into()]);
            codegen.builder.switch_to_block(blk_block);
            let input_id_ptr = codegen.variables.get("id").unwrap();
            let input_id_ptr = codegen.builder.use_var(*input_id_ptr);
            let curr_id_index = codegen.builder.ins().iadd(id_start_index, curr_blk_index);
            let indexed_id_ptr =
                codegen.ptr_add_offset(codegen.real_type, input_id_ptr, curr_id_index);
            // 0.0 marks algebraic states, 1.0 marks differential states.
            let is_algebraic_float = if *is_algebraic { 0.0 } else { 1.0 };
            let is_algebraic_value = codegen.fconst(is_algebraic_float);
            codegen.builder.ins().store(
                codegen.mem_flags,
                is_algebraic_value,
                indexed_id_ptr,
                0,
            );
            // NOTE(review): store happens before the bound test (do-while),
            // so a block with nnz() == 0 would still write one element —
            // confirm state blocks are always non-empty.
            let one = codegen.builder.ins().iconst(codegen.int_type, 1);
            let next_index = codegen.builder.ins().iadd(curr_blk_index, one);
            let loop_while = codegen.builder.ins().icmp_imm(
                IntCC::UnsignedLessThan,
                next_index,
                i64::try_from(blk.nnz()).unwrap(),
            );
            let post_block = codegen.builder.create_block();
            codegen.builder.ins().brif(
                loop_while,
                blk_block,
                &[next_index.into()],
                post_block,
                &[],
            );
            codegen.builder.seal_block(blk_block);
            codegen.builder.seal_block(post_block);
            codegen.builder.switch_to_block(post_block);
            id_index += blk.nnz();
        }
        codegen.builder.ins().return_(&[]);
        codegen.builder.finalize();
    }
    self.declare_function("set_id")
}
}
impl<M: Module + Send + 'static> CodegenModule for CraneliftModule<M> {}
impl CodegenModuleCompile for CraneliftModule<ObjectModule> {
    /// Build a module that emits a relocatable object file for `triple`
    /// (defaulting to the host) from the discretised model.
    fn from_discrete_model(
        model: &DiscreteModel,
        options: CompilerOptions,
        triple: Option<Triple>,
        real_type: RealType,
        _code: Option<&str>,
    ) -> Result<Self> {
        // More than one worker thread => emit the barrier machinery.
        let threaded = options.mode.thread_dim(model.state().nnz()) > 1;
        let target = triple.unwrap_or(Triple::host());
        let mut settings_builder = settings::builder();
        settings_builder.set("use_colocated_libcalls", "false").unwrap();
        settings_builder.set("is_pic", "false").unwrap();
        settings_builder.set("opt_level", "speed").unwrap();
        let flags = settings::Flags::new(settings_builder);
        let isa = isa::lookup(target.clone())?.finish(flags)?;
        let object_builder =
            ObjectBuilder::new(isa, "diffsol", cranelift_module::default_libcall_names())?;
        let object_module = ObjectModule::new(object_builder);
        Self::new(target, model, threaded, object_module, real_type)
    }
}
impl CodegenModuleCompile for CraneliftModule<JITModule> {
    /// Build an in-process JIT module. Registers the runtime math helper
    /// symbols (primal and tangent, matching `real_type`) before creating
    /// the `JITModule`.
    fn from_discrete_model(
        model: &DiscreteModel,
        options: CompilerOptions,
        triple: Option<Triple>,
        real_type: RealType,
        _code: Option<&str>,
    ) -> Result<Self> {
        let thread_dim = options.mode.thread_dim(model.state().nnz());
        let threaded = thread_dim > 1;
        // NOTE(review): `triple` only feeds `Self::new` (pointer width /
        // endianness); the ISA below is always the host's. A non-host
        // triple would be inconsistent here — confirm callers never pass one
        // for the JIT path.
        let triple = triple.unwrap_or(Triple::host());
        let mut flag_builder = settings::builder();
        flag_builder.set("use_colocated_libcalls", "false").unwrap();
        flag_builder.set("is_pic", "false").unwrap();
        flag_builder.set("opt_level", "speed").unwrap();
        let flags = settings::Flags::new(flag_builder);
        let isa_builder = cranelift_native::builder().unwrap_or_else(|msg| {
            panic!("host machine is not supported: {msg}");
        });
        let isa = isa_builder.finish(flags).unwrap();
        let mut builder = JITBuilder::with_isa(isa, cranelift_module::default_libcall_names());
        // One-argument math functions: register (name -> fn) and
        // (tangent name -> dfn) for the selected precision.
        for i in 0..crate::execution::functions::FUNCTIONS.len() {
            let (f_name, f_ptr, df_ptr) = match real_type {
                RealType::F32 => (
                    crate::execution::functions::FUNCTIONS_F32[i].0,
                    crate::execution::functions::FUNCTIONS_F32[i].1 as *const u8,
                    crate::execution::functions::FUNCTIONS_F32[i].2 as *const u8,
                ),
                RealType::F64 => (
                    crate::execution::functions::FUNCTIONS_F64[i].0,
                    crate::execution::functions::FUNCTIONS_F64[i].1 as *const u8,
                    crate::execution::functions::FUNCTIONS_F64[i].2 as *const u8,
                ),
            };
            builder.symbol(f_name, f_ptr);
            builder.symbol(
                CraneliftCodeGen::<JITModule>::get_function_name(f_name, true),
                df_ptr,
            );
        }
        // Two-argument math functions, registered the same way.
        for i in 0..crate::execution::functions::TWO_ARG_FUNCTIONS.len() {
            let (f_name, f_ptr, df_ptr) = match real_type {
                RealType::F32 => (
                    crate::execution::functions::TWO_ARG_FUNCTIONS_F32[i].0,
                    crate::execution::functions::TWO_ARG_FUNCTIONS_F32[i].1 as *const u8,
                    crate::execution::functions::TWO_ARG_FUNCTIONS_F32[i].2 as *const u8,
                ),
                RealType::F64 => (
                    crate::execution::functions::TWO_ARG_FUNCTIONS_F64[i].0,
                    crate::execution::functions::TWO_ARG_FUNCTIONS_F64[i].1 as *const u8,
                    crate::execution::functions::TWO_ARG_FUNCTIONS_F64[i].2 as *const u8,
                ),
            };
            builder.symbol(f_name, f_ptr);
            builder.symbol(
                CraneliftCodeGen::<JITModule>::get_function_name(f_name, true),
                df_ptr,
            );
        }
        let module = JITModule::new(builder);
        Self::new(triple, model, threaded, module, real_type)
    }
}
impl CodegenModuleEmit for CraneliftModule<ObjectModule> {
    /// Consume the module and emit the finished object-file bytes.
    fn to_object(self) -> Result<Vec<u8>> {
        // Unwrap the mutex; poisoning means a codegen thread panicked.
        let inner = self.module.into_inner().unwrap();
        inner.finish().emit().map_err(|e| anyhow!(e))
    }
}
impl CodegenModuleJit for CraneliftModule<JITModule> {
    /// Finalize all definitions and return a map of exported function
    /// names to their finalized machine-code addresses.
    fn jit(&mut self) -> Result<HashMap<String, *const u8>> {
        let mut module = self.module.lock().unwrap();
        module.finalize_definitions()?;
        let mut symbols = HashMap::new();
        for (func, decl) in module.declarations().get_functions() {
            // Imported symbols have no code of their own; skip them.
            if decl.linkage == Linkage::Import {
                continue;
            }
            let name = decl.name.as_ref().unwrap().clone();
            symbols.insert(name, module.get_finalized_function(func));
        }
        Ok(symbols)
    }
}
/// Per-function code generator: wraps a `FunctionBuilder` plus the shared
/// module state needed while lowering one exported function.
struct CraneliftCodeGen<'a, M: Module> {
    int_type: types::Type,
    real_type: types::Type,
    real_ptr_type: types::Type,
    int_ptr_type: types::Type,
    builder: FunctionBuilder<'a>,
    // Holds the module lock for the lifetime of this codegen.
    module: MutexGuard<'a, M>,
    // Presumably the destination pointer of the tensor currently being
    // lowered — confirm in `jit_compile_tensor`.
    tensor_ptr: Option<Value>,
    // Named SSA variables: function arguments and tensor pointers.
    variables: HashMap<String, Variable>,
    mem_flags: MemFlags,
    // Function references resolved by name (see `get_function`).
    functions: HashMap<String, FuncRef>,
    layout: &'a DataLayout,
    // Module-level globals imported into the current function.
    indices: GlobalValue,
    constants: GlobalValue,
    model_index_global: GlobalValue,
    threaded: bool,
}
impl<'ctx, M: Module> CraneliftCodeGen<'ctx, M> {
/// Emit a floating-point constant of the module's real type.
/// Panics if `real_type` is neither F32 nor F64.
fn fconst(&mut self, value: f64) -> Value {
    if self.real_type == types::F32 {
        self.builder.ins().f32const(value as f32)
    } else if self.real_type == types::F64 {
        self.builder.ins().f64const(value)
    } else {
        panic!("unexpected real type")
    }
}
/// Emit `ptr + offset * size_of(elmt_ty)` for a compile-time-known
/// element offset; the byte offset is folded into a single constant.
fn ptr_add_offset_i64(&mut self, elmt_ty: types::Type, ptr: Value, offset: i64) -> Value {
    let byte_offset = offset * (elmt_ty.bytes() as i64);
    let byte_offset_val = self.builder.ins().iconst(self.real_ptr_type, byte_offset);
    self.builder.ins().iadd(ptr, byte_offset_val)
}
/// Emit `ptr + offset * size_of(elmt_ty)` for a runtime element offset,
/// sign-extending `offset` to pointer width when the integer type is
/// narrower than pointers.
fn ptr_add_offset(&mut self, elmt_ty: types::Type, ptr: Value, offset: Value) -> Value {
    let elem_bytes = elmt_ty.bytes() as i64;
    let ptr_ty = self.real_ptr_type;
    let elem_bytes_val = self.builder.ins().iconst(ptr_ty, elem_bytes);
    let wide_offset = if self.int_type != ptr_ty {
        self.builder.ins().sextend(ptr_ty, offset)
    } else {
        offset
    };
    let byte_offset = self.builder.ins().imul(wide_offset, elem_bytes_val);
    self.builder.ins().iadd(ptr, byte_offset)
}
/// Emit a call to the spin barrier. The shared counter accumulates across
/// barriers, so the wait threshold is `threadDim * (nbarrier + 1)`.
/// No-op in single-threaded mode.
fn jit_compile_call_barrier(&mut self, nbarrier: i64) {
    if !self.threaded {
        return;
    }
    let dim_var = *self.variables.get("threadDim").unwrap();
    let dim = self.builder.use_var(dim_var);
    let factor = self.builder.ins().iconst(self.int_type, nbarrier + 1);
    let threshold = self.builder.ins().imul(dim, factor);
    let barrier_ref = self.get_function("barrier", false).unwrap();
    self.builder.ins().call(barrier_ref, &[threshold]);
}
/// Recursively lower a real-valued tensor-expression AST node to
/// Cranelift IR, returning the computed value.
///
/// * `name` - fallback name of the enclosing block (the element's own
///   name wins when present).
/// * `index` - current loop-index values, one per expression dimension.
/// * `elmt` - tensor block being compiled; supplies index labels and
///   layouts.
/// * `expr_index` - flattened index of the current element in the
///   expression layout, used for sparse/diagonal tensor reads.
///
/// Errors on unknown operators, unknown function calls, and invalid
/// index expressions. (Fixes vs. previous revision: removed the dead
/// `no_transform` accumulator that was computed but never read, and
/// corrected the "unexprected" typo in the fallback panic message.)
fn jit_compile_expr(
    &mut self,
    name: &str,
    expr: &Ast,
    index: &[Value],
    elmt: &TensorBlock,
    expr_index: Value,
) -> Result<Value> {
    let name = elmt.name().unwrap_or(name);
    match &expr.kind {
        AstKind::Binop(binop) => {
            let lhs =
                self.jit_compile_expr(name, binop.left.as_ref(), index, elmt, expr_index)?;
            let rhs =
                self.jit_compile_expr(name, binop.right.as_ref(), index, elmt, expr_index)?;
            match binop.op {
                '*' => Ok(self.builder.ins().fmul(lhs, rhs)),
                '/' => Ok(self.builder.ins().fdiv(lhs, rhs)),
                '-' => Ok(self.builder.ins().fsub(lhs, rhs)),
                '+' => Ok(self.builder.ins().fadd(lhs, rhs)),
                unknown => Err(anyhow!("unknown binop op '{}'", unknown)),
            }
        }
        AstKind::Monop(monop) => {
            let child =
                self.jit_compile_expr(name, monop.child.as_ref(), index, elmt, expr_index)?;
            match monop.op {
                '-' => Ok(self.builder.ins().fneg(child)),
                unknown => Err(anyhow!("unknown monop op '{}'", unknown)),
            }
        }
        AstKind::Call(call) => match self.get_function(call.fn_name, call.is_tangent) {
            Some(function) => {
                // Compile each argument, then call and take the single result.
                let mut args = Vec::new();
                for arg in call.args.iter() {
                    let arg_val =
                        self.jit_compile_expr(name, arg.as_ref(), index, elmt, expr_index)?;
                    args.push(arg_val);
                }
                let call = self.builder.ins().call(function, &args);
                let ret_value = self.builder.inst_results(call)[0];
                Ok(ret_value)
            }
            None => Err(anyhow!("unknown function call '{}'", call.fn_name)),
        },
        AstKind::CallArg(arg) => {
            self.jit_compile_expr(name, &arg.expression, index, elmt, expr_index)
        }
        AstKind::Number(value) => Ok(self.fconst(*value)),
        AstKind::Name(iname) => {
            // "N" is the (integer) model index converted to a real; its
            // tangent is identically zero.
            if iname.name == "N" {
                if iname.is_tangent {
                    return Ok(self.fconst(0.0));
                }
                let var = self
                    .variables
                    .get("model_index")
                    .ok_or_else(|| anyhow!("N used where model_index is unavailable"))?;
                let model_index = self.builder.use_var(*var);
                return Ok(self
                    .builder
                    .ins()
                    .fcvt_from_sint(self.real_type, model_index));
            }
            // Tangent reads: constants have zero tangent; otherwise read
            // from the tensor's tangent buffer.
            let ptr = if iname.is_tangent {
                if self.layout.is_constant(iname.name) {
                    return Ok(self.fconst(0.0));
                }
                let name = self.get_tangent_tensor_name(iname.name);
                self.builder
                    .use_var(*self.variables.get(name.as_str()).unwrap())
            } else {
                self.builder
                    .use_var(*self.variables.get(iname.name).unwrap())
            };
            // "t" is passed by value, not as a pointer.
            if iname.name == "t" {
                return Ok(ptr);
            }
            let layout = self.layout.get_layout(iname.name).unwrap();
            let iname_elmt_index = if layout.is_dense() {
                // Map each of the tensor's index labels to the matching
                // loop index of the enclosing block; unmatched labels
                // broadcast (index 0). An explicit start index from an
                // index expression is added when present.
                let mut iname_index = Vec::new();
                for c in iname.indices.iter() {
                    let pi = elmt
                        .indices()
                        .iter()
                        .position(|x| x == c)
                        .unwrap_or(elmt.indices().len());
                    if let Some(indice_ast) = iname.indice.as_ref() {
                        let Some(indice) = indice_ast.kind.as_indice() else {
                            return Err(anyhow!("invalid index expression '{}'", indice_ast));
                        };
                        let start_intval =
                            self.jit_compile_integer_expr(indice.first.as_ref())?;
                        let index_pi = if pi >= index.len() {
                            self.builder.ins().iconst(self.int_type, 0)
                        } else {
                            index[pi]
                        };
                        let index_pi = self.builder.ins().iadd(start_intval, index_pi);
                        iname_index.push(index_pi);
                    } else {
                        let index_pi = if pi >= index.len() {
                            self.builder.ins().iconst(self.int_type, 0)
                        } else {
                            index[pi]
                        };
                        iname_index.push(index_pi);
                    }
                }
                // Size-1 dimensions broadcast: force their index to zero.
                let iname_index = iname_index
                    .into_iter()
                    .enumerate()
                    .map(|(i, idx)| {
                        if layout.shape()[i] == 1 {
                            self.builder.ins().iconst(self.int_type, 0)
                        } else {
                            idx
                        }
                    })
                    .collect::<Vec<_>>();
                // Row-major flattening: innermost index plus
                // stride-weighted outer indices.
                if !iname_index.is_empty() {
                    let mut iname_elmt_index = *iname_index.last().unwrap();
                    let mut stride = 1u64;
                    for i in (0..iname_index.len() - 1).rev() {
                        let iname_i = iname_index[i];
                        let shapei: u64 = layout.shape()[i + 1].try_into().unwrap();
                        stride *= shapei;
                        let stride_intval = self
                            .builder
                            .ins()
                            .iconst(self.int_type, i64::try_from(stride).unwrap());
                        let stride_mul_i = self.builder.ins().imul(stride_intval, iname_i);
                        iname_elmt_index =
                            self.builder.ins().iadd(iname_elmt_index, stride_mul_i);
                    }
                    iname_elmt_index
                } else {
                    self.builder.ins().iconst(self.int_type, 0)
                }
            } else if layout.is_sparse() || layout.is_diagonal() {
                let expr_layout = elmt.expr_layout();
                if expr_layout != layout {
                    // Layouts differ: translate the expression's nonzero
                    // index into the tensor's nonzero index via a
                    // precomputed mapping in the global indices array.
                    // A mapped index < 0 marks a structurally-zero entry,
                    // so the load is guarded and yields 0.0 instead.
                    let permutation =
                        DataLayout::permutation(elmt, iname.indices.as_slice(), layout);
                    if let Some(base_binary_layout_index) =
                        self.layout
                            .get_binary_layout_index(layout, expr_layout, permutation)
                    {
                        let base_binary_layout_index = self.builder.ins().iconst(
                            self.int_type,
                            i64::try_from(base_binary_layout_index).unwrap(),
                        );
                        let binary_layout_index = self
                            .builder
                            .ins()
                            .iadd(base_binary_layout_index, expr_index);
                        let indices_array = self
                            .builder
                            .ins()
                            .global_value(self.int_ptr_type, self.indices);
                        let indices_ptr = self.ptr_add_offset(
                            self.int_type,
                            indices_array,
                            binary_layout_index,
                        );
                        let mapped_index = self.builder.ins().load(
                            self.int_type,
                            self.mem_flags,
                            indices_ptr,
                            0,
                        );
                        let is_less_than_zero =
                            self.builder
                                .ins()
                                .icmp_imm(IntCC::SignedLessThan, mapped_index, 0);
                        // Diamond: zero branch / load branch, merged via a
                        // block parameter acting as the phi.
                        let is_less_than_zero_block = self.builder.create_block();
                        let not_less_than_zero_block = self.builder.create_block();
                        let merge_block = self.builder.create_block();
                        let phi_value =
                            self.builder.append_block_param(merge_block, self.real_type);
                        self.builder.ins().brif(
                            is_less_than_zero,
                            is_less_than_zero_block,
                            &[],
                            not_less_than_zero_block,
                            &[],
                        );
                        self.builder.seal_block(is_less_than_zero_block);
                        self.builder.seal_block(not_less_than_zero_block);
                        self.builder.switch_to_block(is_less_than_zero_block);
                        let zero = self.fconst(0.);
                        self.builder.ins().jump(merge_block, &[zero.into()]);
                        self.builder.switch_to_block(not_less_than_zero_block);
                        let value_ptr = self.ptr_add_offset(self.real_type, ptr, mapped_index);
                        let value = self.builder.ins().load(
                            self.real_type,
                            self.mem_flags,
                            value_ptr,
                            0,
                        );
                        self.builder.ins().jump(merge_block, &[value.into()]);
                        self.builder.seal_block(merge_block);
                        self.builder.switch_to_block(merge_block);
                        return Ok(phi_value);
                    } else {
                        expr_index
                    }
                } else {
                    expr_index
                }
            } else {
                panic!("unexpected layout");
            };
            let value_ptr = self.ptr_add_offset(self.real_type, ptr, iname_elmt_index);
            Ok(self
                .builder
                .ins()
                .load(self.real_type, self.mem_flags, value_ptr, 0))
        }
        AstKind::NamedGradient(name) => {
            // Named gradients are scalars read through their own pointer.
            let name_str = name.to_string();
            let ptr = self
                .builder
                .use_var(*self.variables.get(name_str.as_str()).unwrap());
            Ok(self
                .builder
                .ins()
                .load(self.real_type, self.mem_flags, ptr, 0))
        }
        AstKind::Index(_) => todo!(),
        AstKind::Slice(_) => todo!(),
        AstKind::Integer(_) => todo!(),
        _ => panic!("unexpected astkind"),
    }
}
/// Lower an integer-valued AST expression (tensor index arithmetic) to
/// Cranelift IR. Supports integer literals, whole-number real literals,
/// the model index "N", unary +/-, and the binary ops + - * / %; any
/// other expression is an error.
fn jit_compile_integer_expr(&mut self, expr: &Ast) -> Result<Value> {
    match &expr.kind {
        AstKind::Integer(value) => Ok(self.builder.ins().iconst(self.int_type, *value)),
        AstKind::Number(value) => {
            // Real literals are accepted only when they are whole numbers.
            if value.fract() == 0.0 {
                Ok(self.builder.ins().iconst(self.int_type, *value as i64))
            } else {
                Err(anyhow!(
                    "non-integer value '{}' in integer expression",
                    value
                ))
            }
        }
        AstKind::Name(iname) if iname.name == "N" => {
            let var = self
                .variables
                .get("model_index")
                .ok_or_else(|| anyhow!("N used where model_index is unavailable"))?;
            Ok(self.builder.use_var(*var))
        }
        AstKind::Name(iname) => Err(anyhow!(
            "unsupported name '{}' in integer expression",
            iname.name
        )),
        AstKind::Monop(monop) => {
            let operand = self.jit_compile_integer_expr(monop.child.as_ref())?;
            match monop.op {
                '+' => Ok(operand),
                '-' => Ok(self.builder.ins().ineg(operand)),
                _ => Err(anyhow!("unknown integer unary op '{}'", monop.op)),
            }
        }
        AstKind::Binop(binop) => {
            let left = self.jit_compile_integer_expr(binop.left.as_ref())?;
            let right = self.jit_compile_integer_expr(binop.right.as_ref())?;
            let result = match binop.op {
                '+' => self.builder.ins().iadd(left, right),
                '-' => self.builder.ins().isub(left, right),
                '*' => self.builder.ins().imul(left, right),
                '/' => self.builder.ins().sdiv(left, right),
                '%' => self.builder.ins().srem(left, right),
                _ => return Err(anyhow!("unknown integer binary op '{}'", binop.op)),
            };
            Ok(result)
        }
        _ => Err(anyhow!("unsupported integer expression '{}'", expr)),
    }
}
/// Mangle a function name for tangent (derivative) code: tangent
/// variants get a `__tangent__` suffix, primal names pass through.
fn get_function_name(name: &str, is_tangent: bool) -> String {
    match is_tangent {
        true => format!("{name}__tangent__"),
        false => name.to_owned(),
    }
}
/// Resolve a callable by base name (mangled for tangents), declaring it
/// in the current function on first use and caching the reference.
/// Known runtime functions are declared as imports with real-typed
/// params/return; otherwise fall back to a function already present in
/// the module. Returns `None` when the name is unknown.
fn get_function(&mut self, base_name: &str, is_tangent: bool) -> Option<FuncRef> {
    let name = Self::get_function_name(base_name, is_tangent);
    if let Some(&cached) = self.functions.get(name.as_str()) {
        return Some(cached);
    }
    let func_ref = if let Some(num_args) =
        crate::execution::functions::function_num_args(base_name, is_tangent)
    {
        // Known runtime function: import it with `num_args` real params
        // and a single real return.
        let mut sig = self.module.make_signature();
        for _ in 0..num_args {
            sig.params.push(AbiParam::new(self.real_type));
        }
        sig.returns.push(AbiParam::new(self.real_type));
        let callee = self
            .module
            .declare_function(name.as_str(), Linkage::Import, &sig)
            .expect("problem declaring function");
        self.module.declare_func_in_func(callee, self.builder.func)
    } else if let Some(FuncOrDataId::Func(func_id)) = self.module.get_name(name.as_str()) {
        // Already declared/defined in this module.
        self.module.declare_func_in_func(func_id, self.builder.func)
    } else {
        return None;
    };
    self.functions.insert(name, func_ref);
    Some(func_ref)
}
/// Compile all elements of tensor `a` into its destination buffer and
/// return the destination pointer.
///
/// The destination comes from `var` when given, otherwise from the
/// named variable for the tensor (or its tangent). Rank-0 tensors are
/// stored by the first thread only (when threaded); higher ranks
/// compile each element block in turn.
fn jit_compile_tensor(
    &mut self,
    a: &Tensor,
    var: Option<Variable>,
    is_tangent: bool,
) -> Result<Value> {
    let dest = match var {
        Some(v) => self.builder.use_var(v),
        None => {
            let var_name = if is_tangent {
                self.get_tangent_tensor_name(a.name())
            } else {
                a.name().to_owned()
            };
            let dest_var = *self
                .variables
                .get(var_name.as_str())
                .unwrap_or_else(|| panic!("tensor {} not defined", a.name()));
            self.builder.use_var(dest_var)
        }
    };
    self.tensor_ptr = Some(dest);
    if a.rank() != 0 {
        // One codegen pass per element block.
        for (i, blk) in a.elmts().iter().enumerate() {
            let default = format!("{}-{}", a.name(), i);
            let blk_name = blk.name().unwrap_or(default.as_str());
            self.jit_compile_block(blk_name, a, blk, is_tangent)?;
        }
        return Ok(self.tensor_ptr.unwrap());
    }
    // Rank 0: a single scalar store, performed by thread 0 only when
    // running threaded.
    let mut exit_block = None;
    if self.threaded {
        let tid_var = self.variables.get("threadId").unwrap();
        let tid = self.builder.use_var(*tid_var);
        let is_first = self.builder.ins().icmp_imm(IntCC::Equal, tid, 0);
        exit_block = Some(self.builder.create_block());
        let body_block = self.builder.create_block();
        self.builder
            .ins()
            .brif(is_first, body_block, &[], exit_block.unwrap(), &[]);
        self.builder.seal_block(body_block);
        self.builder.switch_to_block(body_block);
    }
    let elmt = a.elmts().first().unwrap();
    let expr = if is_tangent {
        elmt.tangent_expr()
    } else {
        elmt.expr()
    };
    let zero = self.builder.ins().iconst(self.int_type, 0);
    let scalar = self.jit_compile_expr(a.name(), expr, &[], elmt, zero)?;
    self.builder
        .ins()
        .store(self.mem_flags, scalar, self.tensor_ptr.unwrap(), 0);
    if let Some(exit_block) = exit_block {
        self.builder.ins().jump(exit_block, &[]);
        self.builder.seal_block(exit_block);
        self.builder.switch_to_block(exit_block);
    }
    Ok(self.tensor_ptr.unwrap())
}
/// Dispatch a tensor-block compilation by its expression layout: dense,
/// diagonal, or sparse (with a dedicated path for sparse contractions).
fn jit_compile_block(
    &mut self,
    name: &str,
    tensor: &Tensor,
    elmt: &TensorBlock,
    is_tangent: bool,
) -> Result<()> {
    let translation = Translation::new(
        elmt.expr_layout(),
        elmt.layout(),
        elmt.start(),
        tensor.layout_ptr(),
    );
    if elmt.expr_layout().is_dense() {
        return self.jit_compile_dense_block(name, elmt, &translation, is_tangent);
    }
    if elmt.expr_layout().is_diagonal() {
        return self.jit_compile_diagonal_block(name, elmt, &translation, is_tangent);
    }
    if elmt.expr_layout().is_sparse() {
        // Sparse contractions need their own loop structure.
        return if matches!(translation.source, TranslationFrom::SparseContraction { .. }) {
            self.jit_compile_sparse_contraction_block(name, elmt, &translation, is_tangent)
        } else {
            self.jit_compile_sparse_block(name, elmt, &translation, is_tangent)
        };
    }
    Err(anyhow!(
        "unsupported block layout: {:?}",
        elmt.expr_layout()
    ))
}
/// Create an explicit stack slot sized for `ty`, optionally storing an
/// initial value into it.
fn decl_stack_slot(&mut self, ty: Type, val: Option<Value>) -> StackSlot {
    let slot_data = StackSlotData::new(StackSlotKind::ExplicitSlot, ty.bytes(), 0);
    let slot = self.builder.create_sized_stack_slot(slot_data);
    if let Some(initial) = val {
        self.builder.ins().stack_store(initial, slot, 0);
    }
    slot
}
/// Compute this thread's [start, end) slice of a `size`-element range,
/// splitting the work evenly across `threadDim` threads. Emits an early
/// branch to the returned exit block when this thread has no work
/// (start >= size); `end` is clamped to `size`.
fn jit_threading_limits(&mut self, size: Value) -> (Value, Value, Block) {
    let one = self.builder.ins().iconst(self.int_type, 1);
    let tid_var = *self.variables.get("threadId").unwrap();
    let tid = self.builder.use_var(tid_var);
    let dim_var = *self.variables.get("threadDim").unwrap();
    let dim = self.builder.use_var(dim_var);
    // start = tid * size / dim
    let scaled = self.builder.ins().imul(tid, size);
    let start = self.builder.ins().udiv(scaled, dim);
    let no_work = self
        .builder
        .ins()
        .icmp(IntCC::UnsignedGreaterThanOrEqual, start, size);
    let exit_block = self.builder.create_block();
    let work_block = self.builder.create_block();
    self.builder
        .ins()
        .brif(no_work, exit_block, &[], work_block, &[]);
    self.builder.seal_block(work_block);
    self.builder.switch_to_block(work_block);
    // end = min((tid + 1) * size / dim, size)
    let tid_next = self.builder.ins().iadd(tid, one);
    let scaled_next = self.builder.ins().imul(tid_next, size);
    let end_raw = self.builder.ins().udiv(scaled_next, dim);
    let in_range = self.builder.ins().icmp(IntCC::UnsignedLessThan, end_raw, size);
    let end = self.builder.ins().select(in_range, end_raw, size);
    (start, end, exit_block)
}
/// Compile a dense tensor block as a nest of counted loops, one per
/// expression dimension, evaluating the element expression at each
/// index and either storing it or accumulating a contraction.
///
/// When `translation.source` is a dense contraction, the innermost
/// `contract_by` dimensions are summed into a stack slot and the sum is
/// stored once per outer index. When `self.threaded`, the outermost
/// loop is split across threads via `jit_threading_limits`.
fn jit_compile_dense_block(
&mut self,
name: &str,
elmt: &TensorBlock,
translation: &Translation,
is_tangent: bool,
) -> Result<()> {
let int_type = self.int_type;
let expr_rank = elmt.expr_layout().rank();
let expr_shape = elmt
.expr_layout()
.shape()
.mapv(|n| i64::try_from(n).unwrap());
let one = self.builder.ins().iconst(int_type, 1);
let zero = self.builder.ins().iconst(int_type, 0);
// Row-major strides of the expression layout, materialized as IR constants.
let mut expr_strides = vec![1i64; expr_rank];
if expr_rank > 0 {
for i in (0..expr_rank - 1).rev() {
expr_strides[i] = expr_strides[i + 1] * expr_shape[i + 1];
}
}
let expr_strides = expr_strides
.iter()
.map(|&s| self.builder.ins().iconst(int_type, s))
.collect::<Vec<_>>();
let mut indices = Vec::new();
let mut blocks = Vec::new();
// For dense contractions: a stack-slot accumulator, the number of
// innermost dims being contracted, and strides for the kept dims.
let (contract_sum, contract_by, contract_strides) =
if let TranslationFrom::DenseContraction {
contract_by,
contract_len: _,
} = translation.source
{
let contract_rank = expr_rank - contract_by;
let mut contract_strides = vec![1i64; contract_rank];
for i in (0..contract_rank - 1).rev() {
contract_strides[i] = contract_strides[i + 1] * expr_shape[i + 1];
}
let contract_strides = contract_strides
.iter()
.map(|&s| self.builder.ins().iconst(int_type, s))
.collect::<Vec<_>>();
(
Some(self.decl_stack_slot(self.real_type, None)),
contract_by,
Some(contract_strides),
)
} else {
(None, 0, None)
};
// Threaded: split the outermost dimension (size 1 for rank 0) across threads.
let (thread_start, thread_end, exit_block) = if self.threaded {
let expr_shape0 = self
.builder
.ins()
.iconst(int_type, *expr_shape.get(0).unwrap_or(&1));
let (start, end, exit_block) = self.jit_threading_limits(expr_shape0);
(Some(start), Some(end), Some(exit_block))
} else {
(None, None, None)
};
// Open one loop block per dimension; each block carries its loop index
// as a block parameter. The accumulator is zeroed on each iteration of
// the innermost non-contracted loop (i == expr_rank - contract_by - 1).
for i in 0..expr_rank {
let block = self.builder.create_block();
let curr_index = self.builder.append_block_param(block, self.int_type);
let curr_index_start = if i == 0 && self.threaded {
thread_start.unwrap()
} else {
zero
};
self.builder.ins().jump(block, &[curr_index_start.into()]);
self.builder.switch_to_block(block);
#[allow(clippy::unnecessary_unwrap)]
if i == expr_rank - contract_by - 1 && contract_sum.is_some() {
let fzero = self.fconst(0.0);
self.builder
.ins()
.stack_store(fzero, contract_sum.unwrap(), 0);
}
indices.push(curr_index);
blocks.push(block);
}
let expr = if is_tangent {
elmt.tangent_expr()
} else {
elmt.expr()
};
// Flattened (row-major) element index, used by sparse/diagonal reads
// inside the expression.
let mut expr_index = *indices.last().unwrap_or(&zero);
let mut stride = 1u64;
if !indices.is_empty() {
for i in (0..indices.len() - 1).rev() {
let iname_i = indices[i];
let shapei: u64 = elmt.expr_layout().shape()[i + 1].try_into().unwrap();
stride *= shapei;
let stride_intval = self
.builder
.ins()
.iconst(self.int_type, i64::try_from(stride).unwrap());
let stride_mul_i = self.builder.ins().imul(stride_intval, iname_i);
expr_index = self.builder.ins().iadd(expr_index, stride_mul_i);
}
}
let float_value =
self.jit_compile_expr(name, expr, indices.as_slice(), elmt, expr_index)?;
if let Some(contract_sum) = contract_sum {
// Contracting: accumulate into the stack slot; the store happens
// when the contracted loops close below.
let contract_sum_value = self
.builder
.ins()
.stack_load(self.real_type, contract_sum, 0);
let new_contract_sum_value = self.builder.ins().fadd(contract_sum_value, float_value);
self.builder
.ins()
.stack_store(new_contract_sum_value, contract_sum, 0);
} else {
// Not contracting: store directly (possibly broadcast) at the
// stride-weighted index.
let expr_index = indices
.iter()
.zip(expr_strides.iter())
.fold(zero, |acc, (i, s)| {
let tmp = self.builder.ins().imul(*i, *s);
self.builder.ins().iadd(acc, tmp)
});
self.jit_compile_broadcast_and_store(name, elmt, float_value, expr_index, translation)?;
}
// Close the loops from innermost to outermost, emitting back-edges;
// the outermost bound is this thread's end when threaded.
for i in (0..expr_rank).rev() {
#[allow(clippy::unnecessary_unwrap)]
if i == expr_rank - contract_by - 1 && contract_sum.is_some() {
// The contracted loops have fully closed: store the accumulated
// sum at the index formed from the kept (non-contracted) indices.
let contract_strides = contract_strides.as_ref().unwrap();
let elmt_index = indices
.iter()
.take(contract_strides.len())
.zip(contract_strides.iter())
.fold(zero, |acc, (i, s)| {
let tmp = self.builder.ins().imul(*i, *s);
self.builder.ins().iadd(acc, tmp)
});
let contract_sum_value =
self.builder
.ins()
.stack_load(self.real_type, contract_sum.unwrap(), 0);
self.jit_compile_store(name, elmt, elmt_index, contract_sum_value, translation)?;
}
let next_index = self.builder.ins().iadd(indices[i], one);
let block = self.builder.create_block();
let loop_cond = if i == 0 && self.threaded {
self.builder
.ins()
.icmp(IntCC::UnsignedLessThan, next_index, thread_end.unwrap())
} else {
self.builder
.ins()
.icmp_imm(IntCC::UnsignedLessThan, next_index, expr_shape[i])
};
self.builder
.ins()
.brif(loop_cond, blocks[i], &[next_index.into()], block, &[]);
self.builder.seal_block(blocks[i]);
self.builder.seal_block(block);
self.builder.switch_to_block(block);
}
if let Some(exit_block) = exit_block {
self.builder.ins().jump(exit_block, &[]);
self.builder.seal_block(exit_block);
self.builder.switch_to_block(exit_block);
}
Ok(())
}
/// Compile a sparse contraction block: one outer-loop iteration per
/// output nonzero, each accumulating the expression over a contiguous
/// range of expression nonzeros. The per-output [start, end) ranges are
/// stored as consecutive index pairs in the translation section of the
/// global indices array; each expression nonzero's per-dimension
/// indices come from the expression layout's section of the same array.
/// When threaded, output nonzeros are split across threads.
fn jit_compile_sparse_contraction_block(
&mut self,
name: &str,
elmt: &TensorBlock,
translation: &Translation,
is_tangent: bool,
) -> Result<()> {
match translation.source {
TranslationFrom::SparseContraction { .. } => {}
_ => {
panic!("expected sparse contraction")
}
}
let int_type = self.int_type;
let zero = self.builder.ins().iconst(int_type, 0);
let one = self.builder.ins().iconst(int_type, 1);
let two = self.builder.ins().iconst(int_type, 2);
// Offsets into the global indices array: the expression layout's index
// section and this translation's [start, end) pair section.
let layout_index = self.layout.get_layout_index(elmt.expr_layout()).unwrap();
let translation_index = self
.layout
.get_translation_index(elmt.expr_layout(), elmt.layout())
.unwrap();
let translation_index = translation_index + translation.get_from_index_in_data_layout();
let final_contract_index = self
.builder
.ins()
.iconst(int_type, i64::try_from(elmt.layout().nnz()).unwrap());
let (thread_start, thread_end, exit_block) = if self.threaded {
let (start, end, exit_block) = self.jit_threading_limits(final_contract_index);
(Some(start), Some(end), Some(exit_block))
} else {
(None, None, None)
};
let contract_sum_var = self.decl_stack_slot(self.real_type, None);
// Outer loop: one iteration per output nonzero (contract_index).
let block = self.builder.create_block();
let contract_index = self.builder.append_block_param(block, self.int_type);
self.builder
.ins()
.jump(block, &[thread_start.unwrap_or(zero).into()]);
self.builder.switch_to_block(block);
// Load this output's contraction range [start_contract, end_contract)
// from the pair at translation_index + 2 * contract_index.
let translation_index_val = self
.builder
.ins()
.iconst(int_type, i64::try_from(translation_index).unwrap());
let double_contract_index = self.builder.ins().imul(two, contract_index);
let start_index = self
.builder
.ins()
.iadd(translation_index_val, double_contract_index);
let end_index = self.builder.ins().iadd(start_index, one);
let indices_array = self
.builder
.ins()
.global_value(self.int_ptr_type, self.indices);
let ptr = self.ptr_add_offset(self.int_type, indices_array, start_index);
let start_contract = self
.builder
.ins()
.load(self.int_type, self.mem_flags, ptr, 0);
let ptr = self.ptr_add_offset(self.int_type, indices_array, end_index);
let end_contract = self
.builder
.ins()
.load(self.int_type, self.mem_flags, ptr, 0);
// Zero the accumulator, then run the inner loop over the range.
let fzero = self.fconst(0.0);
self.builder.ins().stack_store(fzero, contract_sum_var, 0);
let start_contract_block = self.builder.create_block();
let expr_index = self
.builder
.append_block_param(start_contract_block, self.int_type);
self.builder
.ins()
.jump(start_contract_block, &[start_contract.into()]);
self.builder.switch_to_block(start_contract_block);
// Recover this nonzero's per-dimension indices: `rank` consecutive
// entries in the layout section of the indices array.
let rank_val = self.builder.ins().iconst(
self.int_type,
i64::try_from(elmt.expr_layout().rank()).unwrap(),
);
let elmt_index_mult_rank = self.builder.ins().imul(expr_index, rank_val);
let indices_int = (0..elmt.expr_layout().rank())
.map(|i| {
let layout_index_plus_offset = self
.builder
.ins()
.iconst(self.int_type, i64::try_from(layout_index + i).unwrap());
let curr_index = self
.builder
.ins()
.iadd(elmt_index_mult_rank, layout_index_plus_offset);
let ptr = self.ptr_add_offset(self.int_type, indices_array, curr_index);
let index = self
.builder
.ins()
.load(self.int_type, self.mem_flags, ptr, 0);
Ok(index)
})
.collect::<Result<Vec<_>, anyhow::Error>>()?;
let expr = if is_tangent {
elmt.tangent_expr()
} else {
elmt.expr()
};
// Evaluate the expression and fold it into the stack-slot sum.
let float_value =
self.jit_compile_expr(name, expr, indices_int.as_slice(), elmt, expr_index)?;
let contract_sum_value = self
.builder
.ins()
.stack_load(self.real_type, contract_sum_var, 0);
let new_contract_sum_value = self.builder.ins().fadd(contract_sum_value, float_value);
self.builder
.ins()
.stack_store(new_contract_sum_value, contract_sum_var, 0);
let next_elmt_index = self.builder.ins().iadd(expr_index, one);
let loop_while =
self.builder
.ins()
.icmp(IntCC::UnsignedLessThan, next_elmt_index, end_contract);
let post_contract_block = self.builder.create_block();
self.builder.ins().brif(
loop_while,
start_contract_block,
&[next_elmt_index.into()],
post_contract_block,
&[],
);
self.builder.seal_block(start_contract_block);
self.builder.seal_block(post_contract_block);
self.builder.switch_to_block(post_contract_block);
// Inner loop done: store this output element's accumulated sum.
self.jit_compile_store(
name,
elmt,
contract_index,
new_contract_sum_value,
translation,
)?;
let next_contract_index = self.builder.ins().iadd(contract_index, one);
let loop_while = self.builder.ins().icmp(
IntCC::UnsignedLessThan,
next_contract_index,
thread_end.unwrap_or(final_contract_index),
);
let post_block = exit_block.unwrap_or(self.builder.create_block());
self.builder.ins().brif(
loop_while,
block,
&[next_contract_index.into()],
post_block,
&[],
);
self.builder.seal_block(block);
self.builder.switch_to_block(post_block);
self.builder.seal_block(post_block);
Ok(())
}
/// Compile a sparse (non-contracting) block: loop over the expression
/// layout's nonzeros, recover each nonzero's per-dimension indices from
/// the global indices array, evaluate the expression, and store (or
/// broadcast) the result. When threaded, the nonzeros are split across
/// threads.
fn jit_compile_sparse_block(
&mut self,
name: &str,
elmt: &TensorBlock,
translation: &Translation,
is_tangent: bool,
) -> Result<()> {
let int_type = self.int_type;
// Offset of this layout's index section in the global indices array.
let layout_index = self.layout.get_layout_index(elmt.expr_layout()).unwrap();
let zero = self.builder.ins().iconst(int_type, 0);
let one = self.builder.ins().iconst(int_type, 1);
let end_index = self
.builder
.ins()
.iconst(int_type, i64::try_from(elmt.layout().nnz()).unwrap());
let (thread_start, thread_end, exit_block) = if self.threaded {
let (start, end, exit_block) = self.jit_threading_limits(end_index);
(Some(start), Some(end), Some(exit_block))
} else {
(None, None, None)
};
// Loop block takes the current nonzero index as a block parameter.
let loop_start_block = self.builder.create_block();
let curr_index = self.builder.append_block_param(loop_start_block, int_type);
self.builder
.ins()
.jump(loop_start_block, &[thread_start.unwrap_or(zero).into()]);
self.builder.switch_to_block(loop_start_block);
let elmt_index = curr_index;
// Per-dimension indices of this nonzero: `rank` consecutive entries in
// the layout section of the indices array.
let rank_val = self
.builder
.ins()
.iconst(int_type, i64::try_from(elmt.expr_layout().rank()).unwrap());
let elmt_index_mult_rank = self.builder.ins().imul(elmt_index, rank_val);
let indices_int = (0..elmt.expr_layout().rank())
.map(|i| {
let layout_index_plus_offset = self
.builder
.ins()
.iconst(int_type, i64::try_from(layout_index + i).unwrap());
let curr_index = self
.builder
.ins()
.iadd(elmt_index_mult_rank, layout_index_plus_offset);
let indices_ptr = self
.builder
.ins()
.global_value(self.int_ptr_type, self.indices);
let ptr = self.ptr_add_offset(self.int_type, indices_ptr, curr_index);
let index = self
.builder
.ins()
.load(self.int_type, self.mem_flags, ptr, 0);
Ok(index)
})
.collect::<Result<Vec<_>, anyhow::Error>>()?;
let expr = if is_tangent {
elmt.tangent_expr()
} else {
elmt.expr()
};
let float_value =
self.jit_compile_expr(name, expr, indices_int.as_slice(), elmt, elmt_index)?;
self.jit_compile_broadcast_and_store(name, elmt, float_value, elmt_index, translation)?;
// Back-edge; the bound is this thread's end when threaded.
let next_index = self.builder.ins().iadd(elmt_index, one);
let loop_while = self.builder.ins().icmp(
IntCC::UnsignedLessThan,
next_index,
thread_end.unwrap_or(end_index),
);
let post_block = exit_block.unwrap_or(self.builder.create_block());
self.builder.ins().brif(
loop_while,
loop_start_block,
&[next_index.into()],
post_block,
&[],
);
self.builder.seal_block(loop_start_block);
self.builder.switch_to_block(post_block);
self.builder.seal_block(post_block);
Ok(())
}
/// Compile a diagonal block: a single loop over the expression layout's
/// nonzeros in which every dimension shares the same loop index
/// (i, i, ..., i). When threaded, the nonzeros are split across threads.
fn jit_compile_diagonal_block(
&mut self,
name: &str,
elmt: &TensorBlock,
translation: &Translation,
is_tangent: bool,
) -> Result<()> {
let int_type = self.int_type;
let zero = self.builder.ins().iconst(int_type, 0);
let one = self.builder.ins().iconst(int_type, 1);
let block = self.builder.create_block();
let end_index = self
.builder
.ins()
.iconst(int_type, i64::try_from(elmt.expr_layout().nnz()).unwrap());
let (thread_start, thread_end, exit_block) = if self.threaded {
let (start, end, exit_block) = self.jit_threading_limits(end_index);
(Some(start), Some(end), Some(exit_block))
} else {
(None, None, None)
};
// Loop block takes the current diagonal index as a block parameter.
let curr_index = self.builder.append_block_param(block, int_type);
self.builder
.ins()
.jump(block, &[thread_start.unwrap_or(zero).into()]);
self.builder.switch_to_block(block);
let elmt_index = curr_index;
// All dimensions use the same index on the diagonal.
let indices_int = vec![elmt_index; elmt.expr_layout().rank()];
let expr = if is_tangent {
elmt.tangent_expr()
} else {
elmt.expr()
};
let float_value =
self.jit_compile_expr(name, expr, indices_int.as_slice(), elmt, elmt_index)?;
self.jit_compile_broadcast_and_store(name, elmt, float_value, elmt_index, translation)?;
// Back-edge; the bound is this thread's end when threaded.
let next_index = self.builder.ins().iadd(elmt_index, one);
let loop_while = self.builder.ins().icmp(
IntCC::UnsignedLessThan,
next_index,
thread_end.unwrap_or(end_index),
);
let post_block = exit_block.unwrap_or(self.builder.create_block());
self.builder
.ins()
.brif(loop_while, block, &[next_index.into()], post_block, &[]);
self.builder.seal_block(block);
self.builder.switch_to_block(post_block);
self.builder.seal_block(post_block);
Ok(())
}
/// Store a computed element value, expanding it when the translation
/// source requires a broadcast.
///
/// * `Broadcast { broadcast_len, .. }`: store the value at
///   `expr_index * broadcast_len + j` for j in 0..broadcast_len.
/// * `ElementWise` / `DiagonalContraction`: a single store at
///   `expr_index`.
///
/// Returns the block the builder is positioned in afterwards; any other
/// translation source is an error.
fn jit_compile_broadcast_and_store(
&mut self,
name: &str,
elmt: &TensorBlock,
float_value: Value,
expr_index: Value,
translation: &Translation,
) -> Result<Block> {
let int_type = self.int_type;
let one = self.builder.ins().iconst(int_type, 1);
let zero = self.builder.ins().iconst(int_type, 0);
let pre_block = self.builder.current_block().unwrap();
match translation.source {
TranslationFrom::Broadcast {
broadcast_by: _,
broadcast_len,
} => {
// Inner loop writing the same value broadcast_len times.
let bcast_block = self.builder.create_block();
let bcast_start_index = zero;
let bcast_end_index = self
.builder
.ins()
.iconst(int_type, i64::try_from(broadcast_len).unwrap());
let bcast_index = self.builder.append_block_param(bcast_block, self.int_type);
self.builder
.ins()
.jump(bcast_block, &[bcast_start_index.into()]);
self.builder.switch_to_block(bcast_block);
// store_index = expr_index * broadcast_len + bcast_index
let tmp = self.builder.ins().imul(expr_index, bcast_end_index);
let store_index = self.builder.ins().iadd(tmp, bcast_index);
self.jit_compile_store(name, elmt, store_index, float_value, translation)?;
let bcast_next_index = self.builder.ins().iadd(bcast_index, one);
let bcast_cond = self.builder.ins().icmp(
IntCC::UnsignedLessThan,
bcast_next_index,
bcast_end_index,
);
let post_bcast_block = self.builder.create_block();
self.builder.ins().brif(
bcast_cond,
bcast_block,
&[bcast_next_index.into()],
post_bcast_block,
&[],
);
self.builder.seal_block(bcast_block);
self.builder.seal_block(post_bcast_block);
self.builder.switch_to_block(post_bcast_block);
Ok(post_bcast_block)
}
TranslationFrom::ElementWise | TranslationFrom::DiagonalContraction { .. } => {
// One-to-one mapping: a single store, builder stays in pre_block.
self.jit_compile_store(name, elmt, expr_index, float_value, translation)?;
Ok(pre_block)
}
_ => Err(anyhow!("Invalid translation")),
}
}
/// Store `float_value` into the current tensor buffer at the position
/// derived from `store_index` via the translation target:
///
/// * `Contiguous { start, .. }`: output index is `start + store_index`.
/// * `Sparse { .. }`: output index is loaded indirectly from the
///   translation section of the global indices array.
fn jit_compile_store(
&mut self,
_name: &str,
elmt: &TensorBlock,
store_index: Value,
float_value: Value,
translation: &Translation,
) -> Result<()> {
let int_type = self.int_type;
let res_index = match &translation.target {
TranslationTo::Contiguous { start, end: _ } => {
let start_const = self
.builder
.ins()
.iconst(int_type, i64::try_from(*start).unwrap());
self.builder.ins().iadd(start_const, store_index)
}
TranslationTo::Sparse { indices: _ } => {
// Indirect: read the destination index from the indices array at
// the translation's store-index section plus store_index.
let translate_index = self
.layout
.get_translation_index(elmt.expr_layout(), elmt.layout())
.unwrap();
let translate_store_index =
translate_index + translation.get_to_index_in_data_layout();
let translate_store_index = self
.builder
.ins()
.iconst(int_type, i64::try_from(translate_store_index).unwrap());
let elmt_index_strided = store_index;
let curr_index = self
.builder
.ins()
.iadd(elmt_index_strided, translate_store_index);
let indices_ptr = self
.builder
.ins()
.global_value(self.int_ptr_type, self.indices);
let ptr = self.ptr_add_offset(self.int_type, indices_ptr, curr_index);
self.builder
.ins()
.load(self.int_type, self.mem_flags, ptr, 0)
}
};
// Write into the destination tensor at the resolved element index.
let ptr = self.ptr_add_offset(self.real_type, self.tensor_ptr.unwrap(), res_index);
self.builder
.ins()
.store(self.mem_flags, float_value, ptr, 0);
Ok(())
}
/// Define a named SSA variable of type `ty` initialized to `val`, or
/// return the existing variable when the name is already bound (the
/// initializer is ignored in that case).
fn declare_variable(&mut self, ty: types::Type, name: &str, val: Value) -> Variable {
    if let Some(&existing) = self.variables.get(name) {
        return existing;
    }
    let var = self.builder.declare_var(ty);
    self.builder.def_var(var, val);
    self.variables.insert(name.into(), var);
    var
}
/// Name of the variable holding a tensor's tangent (derivative) buffer.
fn get_tangent_tensor_name(&self, name: &str) -> String {
    [name, "__tangent__"].concat()
}
/// Bind SSA pointer variables for a tensor (at `data_index` elements
/// past `ptr`) and for each of its named element blocks at their
/// offsets within the tensor. Tangent bindings use the mangled names.
fn insert_tensor(&mut self, tensor: &Tensor, ptr: Value, data_index: i64, is_tangent: bool) {
    let base_ptr = self.ptr_add_offset_i64(self.real_type, ptr, data_index);
    let tensor_name = match is_tangent {
        true => self.get_tangent_tensor_name(tensor.name()),
        false => tensor.name().to_owned(),
    };
    self.declare_variable(self.real_ptr_type, tensor_name.as_str(), base_ptr);
    // Each named block gets its own pointer; offsets accumulate by nnz.
    let mut offset = data_index;
    for blk in tensor.elmts() {
        if let Some(blk_base) = blk.name() {
            let var_name = if is_tangent {
                self.get_tangent_tensor_name(blk_base)
            } else {
                blk_base.to_owned()
            };
            let blk_ptr = self.ptr_add_offset_i64(self.real_type, ptr, offset);
            self.declare_variable(self.real_ptr_type, var_name.as_str(), blk_ptr);
        }
        offset += i64::try_from(blk.nnz()).unwrap();
    }
}
/// Build a fresh codegen context for one generated function.
///
/// Resets the shared Cranelift function context on `module`, declares one ABI
/// parameter per entry in `arg_types`, binds each parameter to the name in
/// `arg_names`, and pre-registers the model's well-known tensors (`u`, `du`,
/// `dudt`, `out`, `dout`, constants, and the `data`/`ddata` resident tensors)
/// so that later expression codegen can look them up by name.
///
/// The returned value owns the module mutex guard for the duration of the
/// function build ('ctx is the lifetime of the borrow of `module`).
pub fn new(
    module: &'ctx mut CraneliftModule<M>,
    model: &DiscreteModel,
    arg_names: &[&str],
    arg_types: &[Type],
) -> Self {
    // Reset the shared function context and declare one ABI parameter per argument.
    module.ctx.func.signature.params.clear();
    module.ctx.func.signature.returns.clear();
    for ty in arg_types {
        module.ctx.func.signature.params.push(AbiParam::new(*ty));
    }
    let mut builder = FunctionBuilder::new(&mut module.ctx.func, &mut module.builder_context);
    // Acquire the module lock once instead of locking/unlocking around each
    // `declare_data_in_func` call; the guard is reused for the three data
    // declarations below and then moved into the returned context.
    let mut module_guard = module.module.lock().unwrap();
    let indices = module_guard.declare_data_in_func(module.indices_id, builder.func);
    let constants = module_guard.declare_data_in_func(module.constants_id, builder.func);
    let model_index_global =
        module_guard.declare_data_in_func(module.model_index_id, builder.func);
    let entry_block = builder.create_block();
    builder.append_block_params_for_function_params(entry_block);
    builder.switch_to_block(entry_block);
    // Entry block has no predecessors, so it can be sealed immediately.
    builder.seal_block(entry_block);
    let mut codegen = Self {
        int_type: module.int_type,
        real_type: module.real_type,
        real_ptr_type: module.real_ptr_type,
        int_ptr_type: module.int_ptr_type,
        builder,
        module: module_guard,
        tensor_ptr: None,
        indices,
        constants,
        variables: HashMap::new(),
        mem_flags: MemFlags::new(),
        functions: HashMap::new(),
        layout: &module.layout,
        threaded: module.threaded,
        model_index_global,
    };
    // Bind incoming block parameters to named variables so later codegen can
    // reference the function arguments by name.
    for (i, (arg_name, arg_type)) in arg_names.iter().zip(arg_types.iter()).enumerate() {
        let val = codegen.builder.block_params(entry_block)[i];
        codegen.declare_variable(*arg_type, arg_name, val);
    }
    // Functions that do not take `model_index` as an argument load it from
    // the module-level global instead.
    if !codegen.variables.contains_key("model_index") {
        let model_index_ptr = codegen
            .builder
            .ins()
            .global_value(codegen.int_ptr_type, codegen.model_index_global);
        let model_index =
            codegen
                .builder
                .ins()
                .load(codegen.int_type, codegen.mem_flags, model_index_ptr, 0);
        codegen.declare_variable(codegen.int_type, "model_index", model_index);
    }
    // Register the state/output pointers as tensors when the corresponding
    // argument is present; the `d`-prefixed variants are tangents.
    if let Some(u) = codegen.variables.get("u") {
        let u_ptr = codegen.builder.use_var(*u);
        codegen.insert_tensor(model.state(), u_ptr, 0, false);
    }
    if let Some(du) = codegen.variables.get("du") {
        let du_ptr = codegen.builder.use_var(*du);
        codegen.insert_tensor(model.state(), du_ptr, 0, true);
    }
    if let Some(dudt) = codegen.variables.get("dudt") {
        if let Some(state_dot) = model.state_dot() {
            let statedot_ptr = codegen.builder.use_var(*dudt);
            codegen.insert_tensor(state_dot, statedot_ptr, 0, false);
        }
    }
    if let Some(out_var) = codegen.variables.get("out") {
        if let Some(out) = model.out() {
            let out_ptr = codegen.builder.use_var(*out_var);
            codegen.insert_tensor(out, out_ptr, 0, false);
        }
    }
    if let Some(dout) = codegen.variables.get("dout") {
        if let Some(out) = model.out() {
            let dout_ptr = codegen.builder.use_var(*dout);
            codegen.insert_tensor(out, dout_ptr, 0, true);
        }
    }
    // Constant tensors live in the module's constants data section.
    let constants = codegen
        .builder
        .ins()
        .global_value(codegen.real_ptr_type, codegen.constants);
    for tensor in model.constant_defns() {
        let data_index =
            i64::try_from(codegen.layout.get_data_index(tensor.name()).unwrap()).unwrap();
        codegen.insert_tensor(tensor, constants, data_index, false);
    }
    // All remaining tensors reside in the `data` buffer (and `ddata` for
    // tangents) at offsets provided by the data layout.
    let tensors = model.input().into_iter();
    let tensors = tensors.chain(model.input_dep_defns().iter());
    let tensors = tensors.chain(model.time_dep_defns().iter());
    let tensors = tensors.chain(model.state_dep_defns().iter());
    let tensors = tensors.chain(model.state_dep_post_f_defns().iter());
    if let Some(data) = codegen.variables.get("data") {
        let data_ptr = codegen.builder.use_var(*data);
        for tensor in tensors.clone() {
            let data_index =
                i64::try_from(codegen.layout.get_data_index(tensor.name()).unwrap()).unwrap();
            codegen.insert_tensor(tensor, data_ptr, data_index, false);
        }
    }
    if let Some(data) = codegen.variables.get("ddata") {
        let data_ptr = codegen.builder.use_var(*data);
        for tensor in tensors {
            let data_index =
                i64::try_from(codegen.layout.get_data_index(tensor.name()).unwrap()).unwrap();
            codegen.insert_tensor(tensor, data_ptr, data_index, true);
        }
    }
    codegen
}
/// Emit IR that copies the model's input tensor between the caller-supplied
/// `inputs` buffer and the tensor's slot inside the data buffer.
///
/// * `base_data_ptr` — pointer to the start of the data (or tangent-data) buffer.
/// * `is_tangent` — operate on the tangent buffers (`dinputs`/`ddata` naming).
/// * `is_get` — when true, copy data -> inputs (a "get"); when false, copy
///   inputs -> data (a "set").
fn jit_compile_inputs(
    &mut self,
    model: &DiscreteModel,
    base_data_ptr: Value,
    is_tangent: bool,
    is_get: bool,
) {
    // Offset of this input within the `inputs` buffer; there is at most one
    // input tensor, so it always starts at zero.
    let inputs_index = 0;
    if let Some(input) = model.input() {
        let data_index =
            i64::try_from(self.layout.get_data_index(input.name()).unwrap()).unwrap();
        self.insert_tensor(input, base_data_ptr, data_index, is_tangent);
        // The element count is known at Rust compile time. The copy loop
        // below is emitted as a do-while (the body runs before the bound is
        // checked), so emitting it for an empty input would execute one
        // iteration and read/write out of bounds — skip it entirely instead.
        if input.nnz() == 0 {
            return;
        }
        let tensor_name = if is_tangent {
            self.get_tangent_tensor_name(input.name())
        } else {
            input.name().to_owned()
        };
        let data_ptr = self.variables.get(tensor_name.as_str()).unwrap();
        let data_ptr = self.builder.use_var(*data_ptr);
        let input_name = if is_tangent { "dinputs" } else { "inputs" };
        let input_ptr = self.variables.get(input_name).unwrap();
        let input_ptr = self.builder.use_var(*input_ptr);
        let inputs_start_index = self
            .builder
            .ins()
            .iconst(self.int_type, i64::from(inputs_index));
        let start_index = self.builder.ins().iconst(self.int_type, 0);
        // Loop block; the current element index is carried as a block param.
        let input_block = self.builder.create_block();
        let curr_input_index = self.builder.append_block_param(input_block, self.int_type);
        self.builder.ins().jump(input_block, &[start_index.into()]);
        self.builder.switch_to_block(input_block);
        let curr_input_index_plus_start_index = self
            .builder
            .ins()
            .iadd(curr_input_index, inputs_start_index);
        let indexed_input_ptr =
            self.ptr_add_offset(self.real_type, input_ptr, curr_input_index_plus_start_index);
        let indexed_data_ptr = self.ptr_add_offset(self.real_type, data_ptr, curr_input_index);
        // Copy one element in the requested direction.
        if is_get {
            let input_value =
                self.builder
                    .ins()
                    .load(self.real_type, self.mem_flags, indexed_data_ptr, 0);
            self.builder
                .ins()
                .store(self.mem_flags, input_value, indexed_input_ptr, 0);
        } else {
            let input_value =
                self.builder
                    .ins()
                    .load(self.real_type, self.mem_flags, indexed_input_ptr, 0);
            self.builder
                .ins()
                .store(self.mem_flags, input_value, indexed_data_ptr, 0);
        }
        // Advance the index and loop while it is still below nnz.
        let one = self.builder.ins().iconst(self.int_type, 1);
        let next_index = self.builder.ins().iadd(curr_input_index, one);
        let loop_while = self.builder.ins().icmp_imm(
            IntCC::UnsignedLessThan,
            next_index,
            i64::try_from(input.nnz()).unwrap(),
        );
        let post_block = self.builder.create_block();
        self.builder.ins().brif(
            loop_while,
            input_block,
            &[next_index.into()],
            post_block,
            &[],
        );
        self.builder.seal_block(input_block);
        self.builder.seal_block(post_block);
        self.builder.switch_to_block(post_block);
    }
}
}