use aliasable::boxed::AliasableBox;
use anyhow::{anyhow, Result};
use core::panic;
use inkwell::attributes::{Attribute, AttributeLoc};
use inkwell::basic_block::BasicBlock;
use inkwell::builder::Builder;
use inkwell::context::{AsContextRef, Context};
use inkwell::debug_info::AsDIScope;
use inkwell::debug_info::{DICompileUnit, DIFlags, DIFlagsConstants, DebugInfoBuilder};
use inkwell::execution_engine::ExecutionEngine;
use inkwell::intrinsics::Intrinsic;
use inkwell::module::{Linkage, Module};
use inkwell::passes::PassBuilderOptions;
use inkwell::targets::{FileType, InitializationConfig, Target, TargetMachine, TargetTriple};
use inkwell::types::{
BasicMetadataTypeEnum, BasicType, BasicTypeEnum, FloatType, FunctionType, IntType, PointerType,
};
use inkwell::values::{
AsValueRef, BasicMetadataValueEnum, BasicValue, BasicValueEnum, CallSiteValue, FloatValue,
FunctionValue, GlobalValue, IntValue, PointerValue,
};
use inkwell::{
AddressSpace, AtomicOrdering, AtomicRMWBinOp, FloatPredicate, GlobalVisibility, IntPredicate,
OptimizationLevel,
};
use llvm_sys::core::{
LLVMBuildCall2, LLVMGetArgOperand, LLVMGetBasicBlockParent, LLVMGetGlobalParent,
LLVMGetInstructionParent, LLVMGetNamedFunction, LLVMGlobalGetValueType, LLVMIsMultithreaded,
};
use llvm_sys::prelude::{LLVMBuilderRef, LLVMValueRef};
use pest::Span;
use std::collections::HashMap;
use std::ffi::CString;
use std::iter::zip;
use std::pin::Pin;
use target_lexicon::Triple;
use crate::ast::{Ast, AstKind};
use crate::discretise::{DiscreteModel, Tensor, TensorBlock};
use crate::enzyme::{
CConcreteType_DT_Anything, CConcreteType_DT_Double, CConcreteType_DT_Float,
CConcreteType_DT_Integer, CConcreteType_DT_Pointer, CDerivativeMode_DEM_ForwardMode,
CDerivativeMode_DEM_ReverseModeCombined, CFnTypeInfo, CreateEnzymeLogic, CreateTypeAnalysis,
DiffeGradientUtils, EnzymeCreateForwardDiff, EnzymeCreatePrimalAndGradient, EnzymeFreeTypeTree,
EnzymeGradientUtilsNewFromOriginal, EnzymeLogicRef, EnzymeMergeTypeTree, EnzymeNewTypeTreeCT,
EnzymeRegisterCallHandler, EnzymeTypeAnalysisRef, EnzymeTypeTreeOnlyEq, FreeEnzymeLogic,
FreeTypeAnalysis, GradientUtils, IntList, LLVMOpaqueContext, CDIFFE_TYPE_DFT_CONSTANT,
CDIFFE_TYPE_DFT_DUP_ARG, CDIFFE_TYPE_DFT_DUP_NONEED,
};
use crate::execution::compiler::CompilerOptions;
use crate::execution::module::{
CodegenModule, CodegenModuleCompile, CodegenModuleEmit, CodegenModuleJit,
};
use crate::execution::scalar::RealType;
use crate::execution::{DataLayout, Translation, TranslationFrom, TranslationTo};
use lazy_static::lazy_static;
use std::sync::Mutex;
// Process-wide mutex, created lazily on first access.
// NOTE(review): nothing in this chunk locks `my_mutex`, so what it serialises (presumably some
// non-thread-safe LLVM/Enzyme call elsewhere in the file) needs confirming. The lowercase name
// trips `non_upper_case_globals`; check other uses before renaming it.
lazy_static! {
    static ref my_mutex: Mutex<i32> = Mutex::new(0i32);
}
/// Self-referential owner of an LLVM [`Context`] and the [`CodeGen`] that borrows from it.
///
/// It is only ever stored behind `Pin<Box<_>>` (see `LlvmModule::inner`), so the context keeps
/// a fixed address while `codegen` holds references into it.
struct ImmovableLlvmModule {
    // Actually has the lifetime of `context`; the `'static` is produced by a transmute in
    // `LlvmModule::new`.
    // Declared first so it's dropped before `context` (fields drop in declaration order).
    codegen: Option<CodeGen<'static>>,
    // SAFETY: we must never move out of this box as long as `codegen` is alive.
    context: AliasableBox<Context>,
    // Makes the struct `!Unpin`, so it cannot be moved out of its `Pin` without `unsafe`.
    _pin: std::marker::PhantomPinned,
}
/// An LLVM module generated from a [`DiscreteModel`], together with the target machine used to
/// optimise it and to emit object code for it.
pub struct LlvmModule {
    // Pinned so the `Context` that `codegen` borrows from is never moved.
    inner: Pin<Box<ImmovableLlvmModule>>,
    machine: TargetMachine,
}

// SAFETY: NOTE(review): these impls are asserted by hand, not derived. `LlvmModule` holds raw
// LLVM handles (context, module, target machine) whose thread-safety the compiler cannot check.
// Soundness relies on callers never using one `LlvmModule` from several threads at once in a way
// that touches those handles — confirm this invariant against the callers.
unsafe impl Send for LlvmModule {}
unsafe impl Sync for LlvmModule {}
impl LlvmModule {
    /// Creates the target machine, a fresh LLVM context and the [`CodeGen`] for `model`.
    ///
    /// All LLVM targets are initialised first. When `triple` is `None`, the host triple is used
    /// together with the host CPU name and feature string. Otherwise a `"generic"` CPU with no
    /// extra features is used. When `threaded` is set, the barrier helper functions are
    /// generated immediately (see [`CodeGen::new`]).
    ///
    /// # Errors
    /// Returns an error if LLVM does not recognise the target triple, if no target machine can
    /// be created for it, or if [`CodeGen::new`] fails.
    fn new(
        triple: Option<Triple>,
        model: &DiscreteModel,
        threaded: bool,
        real_type: RealType,
        debug: bool,
    ) -> Result<Self> {
        Target::initialize_all(&InitializationConfig::default());
        let (triple_str, native) = match triple {
            Some(ref triple) => (triple.to_string(), false),
            None => (Triple::host().to_string(), true),
        };
        let triple = TargetTriple::create(triple_str.as_str());
        let target = Target::from_triple(&triple)
            .map_err(|e| anyhow!("Failed to find LLVM target for {}: {:?}", triple_str, e))?;
        let (cpu, features) = if native {
            (
                TargetMachine::get_host_cpu_name().to_string(),
                TargetMachine::get_host_cpu_features().to_string(),
            )
        } else {
            ("generic".to_string(), String::new())
        };
        let machine = target
            .create_target_machine(
                &triple,
                cpu.as_str(),
                features.as_str(),
                OptimizationLevel::Aggressive,
                inkwell::targets::RelocMode::Default,
                inkwell::targets::CodeModel::Default,
            )
            .ok_or_else(|| anyhow!("Failed to create target machine for {}", triple_str))?;

        let context = AliasableBox::from_unique(Box::new(Context::create()));
        let mut pinned = Self {
            inner: Box::pin(ImmovableLlvmModule {
                codegen: None,
                context,
                _pin: std::marker::PhantomPinned,
            }),
            machine,
        };
        let context_ref = pinned.inner.context.as_ref();
        let real_type_llvm = match real_type {
            RealType::F32 => context_ref.f32_type(),
            RealType::F64 => context_ref.f64_type(),
        };
        let int_type_llvm = context_ref.i32_type();
        let target_data = pinned.machine.get_target_data();
        let ptr_size_bits =
            target_data.get_bit_size(&context_ref.ptr_type(AddressSpace::default()));
        let real_size_bits = target_data.get_bit_size(&real_type_llvm);
        let int_size_bits = target_data.get_bit_size(&int_type_llvm);
        let codegen = CodeGen::new(
            model,
            context_ref,
            real_type,
            real_type_llvm,
            int_type_llvm,
            threaded,
            ptr_size_bits,
            real_size_bits,
            int_size_bits,
            debug,
        )?;
        // SAFETY: the context lives in a pinned, never-moved `AliasableBox` inside `inner`, and
        // `codegen` is declared before `context` so it is dropped first. Widening the borrow to
        // `'static` therefore never lets `codegen` outlive the context.
        let codegen = unsafe { std::mem::transmute::<CodeGen<'_>, CodeGen<'static>>(codegen) };
        // SAFETY: only the `codegen` field is assigned; the pinned value itself is not moved.
        unsafe { pinned.inner.as_mut().get_unchecked_mut().codegen = Some(codegen) };
        Ok(pinned)
    }

    /// Optimises the module before it is handed to Enzyme.
    ///
    /// The pipeline below is written out explicitly rather than as `default<O2>`. Loop
    /// vectorisation, SLP vectorisation and loop unrolling are disabled, so Enzyme sees
    /// simplified but not unrolled or vectorised IR.
    ///
    /// # Errors
    /// Returns an error if LLVM rejects the pipeline or a pass fails (the pipeline ends with
    /// `verify`).
    fn pre_autodiff_optimisation(&mut self) -> Result<()> {
        let pass_options = PassBuilderOptions::create();
        pass_options.set_loop_vectorization(false);
        pass_options.set_loop_slp_vectorization(false);
        pass_options.set_loop_unrolling(false);
        // Each `\` line continuation drops the newline and the following indentation, so this is
        // one comma-separated pipeline string with no embedded whitespace.
        let passes = "annotation2metadata,forceattrs,inferattrs,coro-early,\
            function<eager-inv>(lower-expect,simplifycfg<bonus-inst-threshold=1;no-forward-switch-cond;no-switch-range-to-icmp;no-switch-to-lookup;keep-loops;no-hoist-common-insts;no-sink-common-insts>,early-cse<>),\
            openmp-opt,ipsccp,called-value-propagation,globalopt,function(mem2reg),\
            function<eager-inv>(instcombine,simplifycfg<bonus-inst-threshold=1;no-forward-switch-cond;switch-range-to-icmp;no-switch-to-lookup;keep-loops;no-hoist-common-insts;no-sink-common-insts>),\
            require<globals-aa>,function(invalidate<aa>),require<profile-summary>,\
            cgscc(devirt<4>(inline<only-mandatory>,inline,function-attrs,openmp-opt-cgscc,\
            function<eager-inv>(early-cse<memssa>,speculative-execution,jump-threading,correlated-propagation,\
            simplifycfg<bonus-inst-threshold=1;no-forward-switch-cond;switch-range-to-icmp;no-switch-to-lookup;keep-loops;no-hoist-common-insts;no-sink-common-insts>,\
            instcombine,libcalls-shrinkwrap,tailcallelim,\
            simplifycfg<bonus-inst-threshold=1;no-forward-switch-cond;switch-range-to-icmp;no-switch-to-lookup;keep-loops;no-hoist-common-insts;no-sink-common-insts>,\
            reassociate,require<opt-remark-emit>,\
            loop-mssa(loop-instsimplify,loop-simplifycfg,licm<no-allowspeculation>,loop-rotate,licm<allowspeculation>,simple-loop-unswitch<no-nontrivial;trivial>),\
            simplifycfg<bonus-inst-threshold=1;no-forward-switch-cond;switch-range-to-icmp;no-switch-to-lookup;keep-loops;no-hoist-common-insts;no-sink-common-insts>,\
            instcombine,loop(loop-idiom,indvars,loop-deletion),vector-combine,mldst-motion<no-split-footer-bb>,\
            gvn<>,sccp,bdce,instcombine,jump-threading,correlated-propagation,adce,memcpyopt,dse,\
            loop-mssa(licm<allowspeculation>),coro-elide,\
            simplifycfg<bonus-inst-threshold=1;no-forward-switch-cond;switch-range-to-icmp;no-switch-to-lookup;keep-loops;hoist-common-insts;sink-common-insts>,\
            instcombine),coro-split)),\
            deadargelim,coro-cleanup,globalopt,globaldce,elim-avail-extern,rpo-function-attrs,recompute-globalsaa,\
            function<eager-inv>(float2int,lower-constant-intrinsics,loop(loop-rotate,loop-deletion),loop-distribute,\
            inject-tli-mappings,loop-load-elim,instcombine,\
            simplifycfg<bonus-inst-threshold=1;forward-switch-cond;switch-range-to-icmp;switch-to-lookup;no-keep-loops;hoist-common-insts;sink-common-insts>,\
            vector-combine,instcombine,transform-warning,instcombine,require<opt-remark-emit>,\
            loop-mssa(licm<allowspeculation>),alignment-from-assumptions,loop-sink,instsimplify,div-rem-pairs,tailcallelim,\
            simplifycfg<bonus-inst-threshold=1;no-forward-switch-cond;switch-range-to-icmp;no-switch-to-lookup;keep-loops;no-hoist-common-insts;no-sink-common-insts>),\
            globaldce,constmerge,cg-profile,rel-lookup-table-converter,function(annotation-remarks),verify";
        let (codegen, machine) = self.codegen_and_machine_mut();
        codegen
            .module()
            .run_passes(passes, machine, pass_options)
            .map_err(|e| anyhow!("Failed to run passes: {:?}", e))
    }

    /// Removes helpers that exist only for Enzyme's benefit, then runs `default<O3>`.
    ///
    /// The `noinline` attribute is removed from `barrier` (if present), and every function whose
    /// name starts with `preprocess_` is deleted.
    ///
    /// # Errors
    /// Returns an error if the O3 pipeline fails.
    fn post_autodiff_optimisation(&mut self) -> Result<()> {
        let codegen = self.codegen_mut();
        // `noinline` on the barrier is only needed while Enzyme runs.
        if let Some(barrier_func) = codegen.module().get_function("barrier") {
            let noinline_kind_id = Attribute::get_named_enum_kind_id("noinline");
            barrier_func.remove_enum_attribute(AttributeLoc::Function, noinline_kind_id);
        }
        // Collect first, so nothing is deleted while the module's function list is being walked.
        let preprocess_fns = codegen
            .module()
            .get_functions()
            .filter(|f| f.get_name().to_bytes().starts_with(b"preprocess_"))
            .collect::<Vec<_>>();
        for f in preprocess_fns {
            // SAFETY: `f` belongs to this module and no other handle to it is used afterwards.
            unsafe { f.delete() };
        }
        let (codegen, machine) = self.codegen_and_machine_mut();
        codegen
            .module()
            .run_passes("default<O3>", machine, PassBuilderOptions::create())
            .map_err(|e| anyhow!("Failed to run passes: {:?}", e))?;
        Ok(())
    }

    /// Prints the module's textual IR to stderr.
    pub fn print(&self) {
        self.codegen().module().print_to_stderr();
    }

    fn codegen_mut(&mut self) -> &mut CodeGen<'static> {
        // SAFETY: only a reference to the `codegen` field is handed out; the pinned struct (and in
        // particular `context`) is never moved.
        unsafe { self.inner.as_mut().get_unchecked_mut() }
            .codegen
            .as_mut()
            .expect("codegen is initialised in LlvmModule::new")
    }

    /// Mutable access to the codegen alongside shared access to the target machine.
    fn codegen_and_machine_mut(&mut self) -> (&mut CodeGen<'static>, &TargetMachine) {
        // `machine` and `inner` are disjoint fields, so both borrows can coexist.
        let machine = &self.machine;
        // SAFETY: see `codegen_mut`.
        let codegen = unsafe { self.inner.as_mut().get_unchecked_mut() }
            .codegen
            .as_mut()
            .expect("codegen is initialised in LlvmModule::new");
        (codegen, machine)
    }

    fn codegen(&self) -> &CodeGen<'static> {
        self.inner
            .as_ref()
            .get_ref()
            .codegen
            .as_ref()
            .expect("codegen is initialised in LlvmModule::new")
    }

    /// Emits the module as an object file and links it into a shared library at `output_path`.
    ///
    /// The linker comes from `DIFFSL_LLVM_LLD`, captured at build time. It may be `lld` (invoked
    /// with an explicit `-flavor`, except when it is `ld64.lld`) or a `clang` driver. The object
    /// file is written to a uniquely named temporary file, which is removed whether or not
    /// linking succeeds.
    ///
    /// NOTE(review): linker arguments are chosen from the OS and architecture this crate was
    /// compiled for, not from the module's target triple.
    ///
    /// # Errors
    /// Returns an error if:
    /// - object emission fails;
    /// - `DIFFSL_LLVM_LLD` was not set at build time;
    /// - the temporary object file cannot be written;
    /// - the host is an unsupported macOS architecture;
    /// - the linker cannot be started or exits unsuccessfully.
    pub fn to_dynamic_library(self, output_path: impl Into<std::path::PathBuf>) -> Result<()> {
        use std::fs;
        use std::sync::atomic::{AtomicUsize, Ordering};

        // Distinguishes temporary object files from concurrent calls within one process.
        static TEMP_OBJECT_COUNTER: AtomicUsize = AtomicUsize::new(0);

        let output_path = output_path.into();
        let object_buffer = self.to_object()?;
        let lld = option_env!("DIFFSL_LLVM_LLD")
            .ok_or_else(|| anyhow!("DIFFSL_LLVM_LLD not set by build script"))?;

        // Unique per process and per call, so concurrent builds cannot clobber each other's object.
        let obj_path = std::env::temp_dir().join(format!(
            "diffsl_temp_object_{}_{}.o",
            std::process::id(),
            TEMP_OBJECT_COUNTER.fetch_add(1, Ordering::Relaxed)
        ));
        fs::write(&obj_path, object_buffer)
            .map_err(|e| anyhow!("Failed to write temporary object file: {}", e))?;

        let result = Self::link_shared_library(lld, &obj_path, &output_path);
        // Best-effort cleanup on both success and failure.
        let _ = fs::remove_file(&obj_path);
        result
    }

    /// Runs `lld` (or a clang driver) to link the single object file `obj_path` into a shared
    /// library at `output_path`.
    fn link_shared_library(
        lld: &str,
        obj_path: &std::path::Path,
        output_path: &std::path::Path,
    ) -> Result<()> {
        use std::process::Command;

        let linker_name = std::path::Path::new(lld)
            .file_name()
            .and_then(|name| name.to_str())
            .unwrap_or(lld);
        let is_clang_driver = linker_name.starts_with("clang");
        let mut command = Command::new(lld);
        if cfg!(target_os = "windows") {
            if is_clang_driver {
                command.args(["-shared", "-o"]).arg(output_path);
            } else {
                command
                    .args(["-flavor", "link", "/DLL"])
                    .arg(format!("/OUT:{}", output_path.display()));
            }
        } else if cfg!(target_os = "macos") {
            let arch = if cfg!(target_arch = "aarch64") {
                "arm64"
            } else if cfg!(target_arch = "x86_64") {
                "x86_64"
            } else {
                return Err(anyhow!("Unsupported macOS architecture for lld invocation"));
            };
            let deployment_target =
                std::env::var("MACOSX_DEPLOYMENT_TARGET").unwrap_or_else(|_| "11.0".to_string());
            if is_clang_driver {
                command
                    .args(["-dynamiclib", "-arch", arch])
                    .arg(format!("-mmacosx-version-min={deployment_target}"));
            } else {
                if !lld.ends_with("ld64.lld") {
                    command.args(["-flavor", "darwin"]);
                }
                command
                    .args(["-arch", arch, "-platform_version", "macos"])
                    .arg(&deployment_target)
                    .arg(&deployment_target)
                    .arg("-dylib");
            }
            command.arg("-o").arg(output_path);
        } else {
            if !is_clang_driver {
                command.args(["-flavor", "gnu"]);
            }
            command.args(["-shared", "-o"]).arg(output_path);
        }
        command.arg(obj_path);

        let status = command
            .status()
            .map_err(|e| anyhow!("Failed to invoke lld: {}", e))?;
        if !status.success() {
            return Err(anyhow!(
                "Dynamic library link failed with status: {}",
                status
            ));
        }
        Ok(())
    }
}
// Empty impl: every item of `CodegenModule` (if it has any) uses the trait's default.
impl CodegenModule for LlvmModule {}
impl CodegenModuleCompile for LlvmModule {
fn from_discrete_model(
model: &DiscreteModel,
options: CompilerOptions,
triple: Option<Triple>,
real_type: RealType,
code: Option<&str>,
) -> Result<Self> {
let thread_dim = options.mode.thread_dim(model.state().nnz());
let threaded = thread_dim > 1;
if (unsafe { LLVMIsMultithreaded() } <= 0) {
return Err(anyhow!(
"LLVM is not compiled with multithreading support, but this codegen module requires it."
));
}
let mut module = Self::new(triple, model, threaded, real_type, options.debug)?;
let set_u0 = module.codegen_mut().compile_set_u0(model, code)?;
let calc_stop = module.codegen_mut().compile_calc_stop(model, false, code)?;
let calc_stop_full = module.codegen_mut().compile_calc_stop(model, true, code)?;
let reset = module.codegen_mut().compile_reset(model, false, code)?;
let reset_full = module.codegen_mut().compile_reset(model, true, code)?;
let rhs = module.codegen_mut().compile_rhs(model, false, code)?;
let rhs_full = module.codegen_mut().compile_rhs(model, true, code)?;
let mass = module.codegen_mut().compile_mass(model, code)?;
let calc_out = module.codegen_mut().compile_calc_out(model, false, code)?;
let calc_out_full = module.codegen_mut().compile_calc_out(model, true, code)?;
let _set_id = module.codegen_mut().compile_set_id(model)?;
let _get_dims = module.codegen_mut().compile_get_dims(model)?;
let set_inputs = module.codegen_mut().compile_inputs(model, false)?;
let _get_inputs = module.codegen_mut().compile_inputs(model, true)?;
let _set_constants = module.codegen_mut().compile_set_constants(model, code)?;
let tensor_info = module
.codegen()
.layout
.tensors()
.map(|(name, is_constant)| (name.to_string(), is_constant))
.collect::<Vec<_>>();
for (tensor, is_constant) in tensor_info {
if is_constant {
module
.codegen_mut()
.compile_get_constant(model, tensor.as_str())?;
} else {
module
.codegen_mut()
.compile_get_tensor(model, tensor.as_str())?;
}
}
module.pre_autodiff_optimisation()?;
module.codegen_mut().compile_gradient(
set_u0,
&[
CompileGradientArgType::DupNoNeed,
CompileGradientArgType::DupNoNeed,
CompileGradientArgType::Const,
CompileGradientArgType::Const,
],
CompileMode::Forward,
"set_u0_grad",
)?;
module.codegen_mut().compile_gradient(
rhs,
&[
CompileGradientArgType::Const,
CompileGradientArgType::DupNoNeed,
CompileGradientArgType::DupNoNeed,
CompileGradientArgType::DupNoNeed,
CompileGradientArgType::Const,
CompileGradientArgType::Const,
],
CompileMode::Forward,
"rhs_grad",
)?;
module.codegen_mut().compile_gradient(
reset,
&[
CompileGradientArgType::Const,
CompileGradientArgType::DupNoNeed,
CompileGradientArgType::DupNoNeed,
CompileGradientArgType::DupNoNeed,
CompileGradientArgType::Const,
CompileGradientArgType::Const,
],
CompileMode::Forward,
"reset_grad",
)?;
module.codegen_mut().compile_gradient(
calc_stop,
&[
CompileGradientArgType::Const,
CompileGradientArgType::DupNoNeed,
CompileGradientArgType::DupNoNeed,
CompileGradientArgType::DupNoNeed,
CompileGradientArgType::Const,
CompileGradientArgType::Const,
],
CompileMode::Forward,
"calc_stop_grad",
)?;
module.codegen_mut().compile_gradient(
calc_out,
&[
CompileGradientArgType::Const,
CompileGradientArgType::DupNoNeed,
CompileGradientArgType::DupNoNeed,
CompileGradientArgType::DupNoNeed,
CompileGradientArgType::Const,
CompileGradientArgType::Const,
],
CompileMode::Forward,
"calc_out_grad",
)?;
module.codegen_mut().compile_gradient(
set_inputs,
&[
CompileGradientArgType::DupNoNeed,
CompileGradientArgType::DupNoNeed,
CompileGradientArgType::Const,
],
CompileMode::Forward,
"set_inputs_grad",
)?;
module.codegen_mut().compile_gradient(
set_u0,
&[
CompileGradientArgType::DupNoNeed,
CompileGradientArgType::DupNoNeed,
CompileGradientArgType::Const,
CompileGradientArgType::Const,
],
CompileMode::Reverse,
"set_u0_rgrad",
)?;
module.codegen_mut().compile_gradient(
mass,
&[
CompileGradientArgType::Const,
CompileGradientArgType::DupNoNeed,
CompileGradientArgType::DupNoNeed,
CompileGradientArgType::DupNoNeed,
CompileGradientArgType::Const,
CompileGradientArgType::Const,
],
CompileMode::Reverse,
"mass_rgrad",
)?;
module.codegen_mut().compile_gradient(
rhs,
&[
CompileGradientArgType::Const,
CompileGradientArgType::DupNoNeed,
CompileGradientArgType::DupNoNeed,
CompileGradientArgType::DupNoNeed,
CompileGradientArgType::Const,
CompileGradientArgType::Const,
],
CompileMode::Reverse,
"rhs_rgrad",
)?;
module.codegen_mut().compile_gradient(
reset,
&[
CompileGradientArgType::Const,
CompileGradientArgType::DupNoNeed,
CompileGradientArgType::DupNoNeed,
CompileGradientArgType::DupNoNeed,
CompileGradientArgType::Const,
CompileGradientArgType::Const,
],
CompileMode::Reverse,
"reset_rgrad",
)?;
module.codegen_mut().compile_gradient(
calc_stop,
&[
CompileGradientArgType::Const,
CompileGradientArgType::DupNoNeed,
CompileGradientArgType::DupNoNeed,
CompileGradientArgType::DupNoNeed,
CompileGradientArgType::Const,
CompileGradientArgType::Const,
],
CompileMode::Reverse,
"calc_stop_rgrad",
)?;
module.codegen_mut().compile_gradient(
calc_out,
&[
CompileGradientArgType::Const,
CompileGradientArgType::DupNoNeed,
CompileGradientArgType::DupNoNeed,
CompileGradientArgType::DupNoNeed,
CompileGradientArgType::Const,
CompileGradientArgType::Const,
],
CompileMode::Reverse,
"calc_out_rgrad",
)?;
module.codegen_mut().compile_gradient(
set_inputs,
&[
CompileGradientArgType::DupNoNeed,
CompileGradientArgType::DupNoNeed,
CompileGradientArgType::Const,
],
CompileMode::Reverse,
"set_inputs_rgrad",
)?;
module.codegen_mut().compile_gradient(
rhs_full,
&[
CompileGradientArgType::Const,
CompileGradientArgType::Const,
CompileGradientArgType::DupNoNeed,
CompileGradientArgType::DupNoNeed,
CompileGradientArgType::Const,
CompileGradientArgType::Const,
],
CompileMode::ForwardSens,
"rhs_sgrad",
)?;
module.codegen_mut().compile_gradient(
reset_full,
&[
CompileGradientArgType::Const,
CompileGradientArgType::Const,
CompileGradientArgType::DupNoNeed,
CompileGradientArgType::DupNoNeed,
CompileGradientArgType::Const,
CompileGradientArgType::Const,
],
CompileMode::ForwardSens,
"reset_sgrad",
)?;
module.codegen_mut().compile_gradient(
calc_stop_full,
&[
CompileGradientArgType::Const,
CompileGradientArgType::Const,
CompileGradientArgType::DupNoNeed,
CompileGradientArgType::DupNoNeed,
CompileGradientArgType::Const,
CompileGradientArgType::Const,
],
CompileMode::ForwardSens,
"calc_stop_sgrad",
)?;
module.codegen_mut().compile_gradient(
set_u0,
&[
CompileGradientArgType::DupNoNeed,
CompileGradientArgType::DupNoNeed,
CompileGradientArgType::Const,
CompileGradientArgType::Const,
],
CompileMode::ForwardSens,
"set_u0_sgrad",
)?;
module.codegen_mut().compile_gradient(
calc_out_full,
&[
CompileGradientArgType::Const,
CompileGradientArgType::Const,
CompileGradientArgType::DupNoNeed,
CompileGradientArgType::DupNoNeed,
CompileGradientArgType::Const,
CompileGradientArgType::Const,
],
CompileMode::ForwardSens,
"calc_out_sgrad",
)?;
module.codegen_mut().compile_gradient(
calc_out_full,
&[
CompileGradientArgType::Const,
CompileGradientArgType::Const,
CompileGradientArgType::DupNoNeed,
CompileGradientArgType::DupNoNeed,
CompileGradientArgType::Const,
CompileGradientArgType::Const,
],
CompileMode::ReverseSens,
"calc_out_srgrad",
)?;
module.codegen_mut().compile_gradient(
rhs_full,
&[
CompileGradientArgType::Const,
CompileGradientArgType::Const,
CompileGradientArgType::DupNoNeed,
CompileGradientArgType::DupNoNeed,
CompileGradientArgType::Const,
CompileGradientArgType::Const,
],
CompileMode::ReverseSens,
"rhs_srgrad",
)?;
module.codegen_mut().compile_gradient(
reset_full,
&[
CompileGradientArgType::Const,
CompileGradientArgType::Const,
CompileGradientArgType::DupNoNeed,
CompileGradientArgType::DupNoNeed,
CompileGradientArgType::Const,
CompileGradientArgType::Const,
],
CompileMode::ReverseSens,
"reset_srgrad",
)?;
module.codegen_mut().compile_gradient(
calc_stop_full,
&[
CompileGradientArgType::Const,
CompileGradientArgType::Const,
CompileGradientArgType::DupNoNeed,
CompileGradientArgType::DupNoNeed,
CompileGradientArgType::Const,
CompileGradientArgType::Const,
],
CompileMode::ReverseSens,
"calc_stop_srgrad",
)?;
module.post_autodiff_optimisation()?;
Ok(module)
}
}
impl CodegenModuleEmit for LlvmModule {
fn to_object(&self) -> Result<Vec<u8>> {
let module = self.codegen().module();
//module.print_to_stderr();
let buffer = self
.machine
.write_to_memory_buffer(module, FileType::Object)
.unwrap()
.as_slice()
.to_vec();
Ok(buffer)
}
}
impl CodegenModuleJit for LlvmModule {
    /// JIT-compiles the module and maps each function name to its machine-code address.
    ///
    /// Functions for which `get_function_address` reports an error are left out of the map.
    ///
    /// # Errors
    /// Returns an error if the JIT execution engine cannot be created.
    fn jit(&mut self) -> Result<HashMap<String, *const u8>> {
        let engine = self
            .codegen()
            .module()
            .create_jit_execution_engine(OptimizationLevel::Default)
            .map_err(|e| anyhow!("Failed to create JIT execution engine: {:?}", e))?;
        let symbols = self
            .codegen()
            .module()
            .get_functions()
            .filter_map(|function| {
                let name = function.get_name().to_str().unwrap();
                engine
                    .get_function_address(name)
                    .ok()
                    .map(|address| (name.to_string(), address as *const u8))
            })
            .collect();
        Ok(symbols)
    }
}
/// Module-level globals shared by the generated functions; built by [`Globals::new`].
struct Globals<'ctx> {
    // Constant integer array holding `layout.indices()`; `None` when there are no indices.
    indices: Option<GlobalValue<'ctx>>,
    // Zero-initialised, mutable real array with one slot per layout constant; `None` if there
    // are no constants.
    constants: Option<GlobalValue<'ctx>>,
    // Integer counter that `barrier_init` resets to zero; present only for threaded modules.
    thread_counter: Option<GlobalValue<'ctx>>,
    // Mutable integer global `enzyme_const_model_index`, zero-initialised.
    model_index: GlobalValue<'ctx>,
}
impl<'ctx> Globals<'ctx> {
    /// Adds the shared globals to `module`, in this order:
    /// 1. a thread counter (only when `threaded`);
    /// 2. a zero-initialised real array sized to `layout.constants()`;
    /// 3. a constant integer array initialised from `layout.indices()`;
    /// 4. a zero-initialised model index.
    ///
    /// The two arrays are skipped when empty. Every global has hidden visibility.
    ///
    /// # Panics
    /// Panics if the constants or indices count does not fit in a `u32`.
    fn new(
        layout: &DataLayout,
        module: &Module<'ctx>,
        int_type: IntType<'ctx>,
        real_type: FloatType<'ctx>,
        threaded: bool,
    ) -> Self {
        // Naming the globals `enzyme_const_*` is what makes Enzyme treat them as inactive.
        // todo: marking the thread counter with `enzyme_inactive` metadata did not have that
        // effect, and adding the metadata made printing the module segfault (possibly an inkwell
        // bug), so the naming convention is relied on instead.
        let thread_counter = threaded.then(|| {
            let counter = module.add_global(
                int_type,
                Some(AddressSpace::default()),
                "enzyme_const_thread_counter",
            );
            counter.set_visibility(GlobalVisibility::Hidden);
            counter.set_initializer(&int_type.const_zero().as_basic_value_enum());
            counter
        });

        let constant_count = layout.constants().len();
        let constants = (constant_count > 0).then(|| {
            let array_type = real_type.array_type(u32::try_from(constant_count).unwrap());
            let global = module.add_global(
                array_type,
                Some(AddressSpace::default()),
                "enzyme_const_constants",
            );
            global.set_visibility(GlobalVisibility::Hidden);
            global.set_constant(false);
            global.set_initializer(&array_type.const_zero());
            global
        });

        let index_values = layout.indices();
        let indices = (!index_values.is_empty()).then(|| {
            let array_type = int_type.array_type(u32::try_from(index_values.len()).unwrap());
            let elements: Vec<IntValue> = index_values
                .iter()
                .map(|&i| int_type.const_int(i as u64, true))
                .collect();
            let global = module.add_global(
                array_type,
                Some(AddressSpace::default()),
                "enzyme_const_indices",
            );
            global.set_constant(true);
            global.set_visibility(GlobalVisibility::Hidden);
            global.set_initializer(&int_type.const_array(&elements));
            global
        });

        let model_index = module.add_global(
            int_type,
            Some(AddressSpace::default()),
            "enzyme_const_model_index",
        );
        model_index.set_visibility(GlobalVisibility::Hidden);
        model_index.set_constant(false);
        model_index.set_initializer(&int_type.const_zero());

        Self {
            indices,
            constants,
            thread_counter,
            model_index,
        }
    }
}
/// Activity of one argument of a function being differentiated by `CodeGen::compile_gradient`.
///
/// NOTE(review): the mapping onto Enzyme's `CDIFFE_TYPE_*` values is done elsewhere. The names
/// suggest constant / duplicated / duplicated-without-primal-result — confirm there.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CompileGradientArgType {
    /// The argument is not differentiated.
    Const,
    /// The argument is paired with a shadow (derivative) argument.
    Dup,
    /// Like `Dup`, presumably without the primal value being needed — confirm against the
    /// Enzyme mapping.
    DupNoNeed,
}
/// Which kind of derivative `CodeGen::compile_gradient` should generate.
///
/// In `from_discrete_model`, `Forward`/`Reverse` produce the `*_grad`/`*_rgrad` functions, and
/// `ForwardSens`/`ReverseSens` produce the `*_sgrad`/`*_srgrad` functions (mostly from the
/// `*_full` primal variants).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CompileMode {
    /// Forward-mode derivative.
    Forward,
    /// Forward-mode sensitivity derivative.
    ForwardSens,
    /// Reverse-mode derivative.
    Reverse,
    /// Reverse-mode sensitivity derivative.
    ReverseSens,
}
/// State for generating LLVM IR for a [`DiscreteModel`] into a single module.
pub struct CodeGen<'ctx> {
    context: &'ctx inkwell::context::Context,
    module: Module<'ctx>,
    builder: Builder<'ctx>,
    // Debug-info builder and compile unit; both are `Some` only when debug info was requested.
    dibuilder: Option<DebugInfoBuilder<'ctx>>,
    compile_unit: Option<DICompileUnit<'ctx>>,
    // Name -> pointer lookup used while generating a function.
    // NOTE(review): presumably reset by `clear()` between functions.
    variables: HashMap<String, PointerValue<'ctx>>,
    // Every function created through `add_function`, keyed by name.
    functions: HashMap<String, FunctionValue<'ctx>>,
    // Function currently being generated (set by `start_function`).
    fn_value_opt: Option<FunctionValue<'ctx>>,
    // NOTE(review): usage not visible in this chunk.
    tensor_ptr_opt: Option<PointerValue<'ctx>>,
    // DiffSL-level real type (F32/F64) that `real_type` corresponds to.
    diffsl_real_type: RealType,
    real_type: FloatType<'ctx>,
    real_ptr_type: PointerType<'ctx>,
    int_type: IntType<'ctx>,
    int_ptr_type: PointerType<'ctx>,
    // Target sizes in bits, used to describe argument types in debug info.
    ptr_size_bits: u64,
    int_size_bits: u64,
    real_size_bits: u64,
    layout: DataLayout,
    globals: Globals<'ctx>,
    // Whether the barrier helper functions were generated for multi-threaded execution.
    threaded: bool,
    // NOTE(review): only ever initialised to `None` in this chunk.
    _ee: Option<ExecutionEngine<'ctx>>,
}
/// Forward-pass custom call handler in the shape of an Enzyme call-handler callback.
///
/// It emits nothing and ignores all of its arguments, returning `1` unconditionally.
/// NOTE(review): where this is registered (presumably `EnzymeRegisterCallHandler` for the barrier
/// call) and what a return value of `1` signals to Enzyme are both defined outside this chunk —
/// confirm.
unsafe extern "C" fn fwd_handler(
    _builder: LLVMBuilderRef,
    _call_instruction: LLVMValueRef,
    _gutils: *mut GradientUtils,
    _dcall: *mut LLVMValueRef,
    _normal_return: *mut LLVMValueRef,
    _shadow_return: *mut LLVMValueRef,
) -> u8 {
    1
}
/// Reverse-pass custom call handler: emits, through `builder`, a call to the module's
/// `barrier_grad` function.
///
/// The call's arguments are the gradient-function counterparts (via
/// `EnzymeGradientUtilsNewFromOriginal`) of operands 0, 1 and 2 of `call_instruction`. These are
/// named `barrier_num`, `total_barriers` and `thread_count` here, presumably matching a call to
/// `barrier`.
///
/// # Safety
/// - `builder`, `call_instruction` and `gutils` must be valid pointers supplied by Enzyme.
/// - `call_instruction` must be a call with at least three arguments, inside a function that
///   belongs to a module.
/// - That module must contain a function named `barrier_grad`. Otherwise
///   `LLVMGetNamedFunction` returns null, and the null pointer is passed on unchecked.
unsafe extern "C" fn rev_handler(
    builder: LLVMBuilderRef,
    call_instruction: LLVMValueRef,
    gutils: *mut DiffeGradientUtils,
    _tape: LLVMValueRef,
) {
    // Walk instruction -> block -> function -> module to find `barrier_grad`.
    let call_block = LLVMGetInstructionParent(call_instruction);
    let call_function = LLVMGetBasicBlockParent(call_block);
    let module = LLVMGetGlobalParent(call_function);
    let name_c_str = CString::new("barrier_grad").unwrap();
    let barrier_func = LLVMGetNamedFunction(module, name_c_str.as_ptr());
    let barrier_func_type = LLVMGlobalGetValueType(barrier_func);
    let barrier_num = LLVMGetArgOperand(call_instruction, 0);
    let total_barriers = LLVMGetArgOperand(call_instruction, 1);
    let thread_count = LLVMGetArgOperand(call_instruction, 2);
    // Map each original operand to its counterpart in the function Enzyme is generating.
    let barrier_num = EnzymeGradientUtilsNewFromOriginal(gutils as *mut GradientUtils, barrier_num);
    let total_barriers =
        EnzymeGradientUtilsNewFromOriginal(gutils as *mut GradientUtils, total_barriers);
    let thread_count =
        EnzymeGradientUtilsNewFromOriginal(gutils as *mut GradientUtils, thread_count);
    let mut args = [barrier_num, total_barriers, thread_count];
    // Empty name: the call's result is not named.
    let name_c_str = CString::new("").unwrap();
    LLVMBuildCall2(
        builder,
        barrier_func_type,
        barrier_func,
        args.as_mut_ptr(),
        args.len() as u32,
        name_c_str.as_ptr(),
    );
}
/// A value to print from generated code with `printf` (see `compile_print_value`).
#[allow(dead_code)]
enum PrintValue<'ctx> {
    /// Printed with a `%f` format.
    Real(FloatValue<'ctx>),
    /// Printed with a `%d` format.
    Int(IntValue<'ctx>),
}
impl<'ctx> CodeGen<'ctx> {
/// Creates the module named after `model`, its builder, layout and globals.
///
/// When `debug` is set, a DWARF compile unit is also created. When `threaded` is set, the
/// barrier helper functions (`barrier_init`, `barrier`, `barrier_grad`) are generated
/// immediately.
///
/// # Errors
/// Propagates failures from generating the barrier helpers.
#[allow(clippy::too_many_arguments)]
pub fn new(
    model: &DiscreteModel,
    context: &'ctx inkwell::context::Context,
    diffsl_real_type: RealType,
    real_type: FloatType<'ctx>,
    int_type: IntType<'ctx>,
    threaded: bool,
    ptr_size_bits: u64,
    real_size_bits: u64,
    int_size_bits: u64,
    debug: bool,
) -> Result<Self> {
    let builder = context.create_builder();
    let layout = DataLayout::new(model);
    let module = context.create_module(model.name());

    // `dibuilder` and `compile_unit` are either both present or both absent.
    let debug_info = debug.then(|| {
        module.create_debug_info_builder(
            true,                                         // allow_unresolved
            inkwell::debug_info::DWARFSourceLanguage::C,  // language
            model.name(),                                 // filename
            ".",                                          // directory
            "diffsl compiler",                            // producer
            true,                                         // is_optimized
            "",                                           // flags
            0,                                            // runtime_ver
            "",                                           // split_name
            inkwell::debug_info::DWARFEmissionKind::Full, // kind
            0,                                            // dwo_id
            false,                                        // split_debug_inlining
            false,                                        // debug_info_for_profiling
            "",                                           // sysroot
            "",                                           // sdk
        )
    });
    let (dibuilder, compile_unit) = match debug_info {
        Some((dib, dic)) => (Some(dib), Some(dic)),
        None => (None, None),
    };

    let globals = Globals::new(&layout, &module, int_type, real_type, threaded);
    let real_ptr_type = Self::pointer_type(context, real_type.into());
    let int_ptr_type = Self::pointer_type(context, int_type.into());

    let mut codegen = Self {
        context,
        module,
        builder,
        dibuilder,
        compile_unit,
        variables: HashMap::new(),
        functions: HashMap::new(),
        fn_value_opt: None,
        tensor_ptr_opt: None,
        diffsl_real_type,
        real_type,
        real_ptr_type,
        int_type,
        int_ptr_type,
        ptr_size_bits,
        int_size_bits,
        real_size_bits,
        layout,
        globals,
        threaded,
        _ee: None,
    };

    // Threaded modules get their barrier helpers up front.
    // todo: registering the barrier with Enzyme here is probably unnecessary unless Enzyme is
    // ever driven through an LLVM pass.
    if threaded {
        codegen.compile_barrier_init()?;
        codegen.compile_barrier()?;
        codegen.compile_barrier_grad()?;
    }
    Ok(codegen)
}
/// Appends an `entry` block to `function`, makes `function` the current function, and places
/// the builder at the end of that block.
///
/// With debug info enabled, the builder's current location is set to line 0, column 0 of the
/// function's subprogram, which must already be attached (as `add_function` does).
fn start_function(
    &mut self,
    function: FunctionValue<'ctx>,
    _code: Option<&str>,
) -> BasicBlock<'ctx> {
    let entry = self.context.append_basic_block(function, "entry");
    self.fn_value_opt = Some(function);
    self.builder.position_at_end(entry);
    if let Some(dibuilder) = self.dibuilder.as_ref() {
        let subprogram = function.get_subprogram().unwrap();
        let location = dibuilder.create_debug_location(
            self.context,
            0,
            0,
            subprogram.as_debug_info_scope(),
            None,
        );
        self.builder.set_current_debug_location(location);
    }
    entry
}
fn add_function(
&mut self,
name: &str,
arg_names: &[&str],
arg_types: &[BasicMetadataTypeEnum<'ctx>],
linkage: Option<Linkage>,
is_real_return: bool,
) -> FunctionValue<'ctx> {
let function_type = if is_real_return {
self.real_type.fn_type(arg_types, false)
} else {
self.context.void_type().fn_type(arg_types, false)
};
let function = self.module.add_function(name, function_type, linkage);
if let (Some(dibuilder), Some(compile_unit)) = (&self.dibuilder, &self.compile_unit) {
let ditypes = arg_names
.iter()
.zip(arg_types.iter())
.map(|(&name, &ty)| {
let size_in_bits = if ty.is_float_type() {
self.real_size_bits
} else if ty.is_int_type() {
self.int_size_bits
} else if ty.is_pointer_type() {
self.ptr_size_bits
} else {
unreachable!("Unsupported argument type for debug info")
};
dibuilder
.create_basic_type(
name,
size_in_bits,
0x00,
<DIFlags as DIFlagsConstants>::PUBLIC,
)
.unwrap()
.as_type()
})
.collect::<Vec<_>>();
let subroutine_type = dibuilder.create_subroutine_type(
compile_unit.get_file(),
None,
&ditypes,
<DIFlags as DIFlagsConstants>::PUBLIC,
);
let func_scope = dibuilder.create_function(
compile_unit.as_debug_info_scope(),
name,
None,
compile_unit.get_file(),
0,
subroutine_type,
true,
true,
0,
<DIFlags as DIFlagsConstants>::PUBLIC,
true,
);
function.set_subprogram(func_scope);
}
self.functions.insert(name.to_owned(), function);
function
}
/// Emits a `printf("<name>: %f\n", value)` call (or `%d` for integers) at the builder's
/// current position.
///
/// `int printf(const char *, ...)` is declared on first use. Each `name` gets one hidden,
/// constant format-string global, named `real_format_<name>` or `int_format_<name>`. Real
/// values that are not already `double` are widened first, because C varargs promote `float`
/// to `double` and `%f` reads a `double`.
///
/// # Errors
/// Returns an error if building the pointer cast, the float extension or the call fails.
///
/// # Panics
/// Panics if `name` contains a NUL byte.
#[allow(dead_code)]
fn compile_print_value(
    &mut self,
    name: &str,
    value: PrintValue<'ctx>,
) -> Result<CallSiteValue<'_>> {
    // Declared directly rather than through `add_function`: that helper only builds
    // non-variadic, void/real-returning functions, and would attach a debug subprogram to what
    // is only a declaration.
    let printf = match self.module.get_function("printf") {
        Some(f) => f,
        None => {
            // int printf(const char *format, ...)
            let printf_type = self
                .context
                .i32_type()
                .fn_type(&[self.int_ptr_type.into()], true);
            let f = self
                .module
                .add_function("printf", printf_type, Some(Linkage::External));
            self.functions.insert("printf".to_owned(), f);
            f
        }
    };
    let (format_str, format_str_name) = match value {
        PrintValue::Real(_) => (format!("{name}: %f\n"), format!("real_format_{name}")),
        PrintValue::Int(_) => (format!("{name}: %d\n"), format!("int_format_{name}")),
    };
    let format_str = CString::new(format_str).unwrap();
    // Reuse the format string global if one was already emitted for this name.
    let format_str_global = match self.module.get_global(format_str_name.as_str()) {
        Some(g) => g,
        None => {
            let format_str = self.context.const_string(format_str.as_bytes(), true);
            let fmt_str =
                self.module
                    .add_global(format_str.get_type(), None, format_str_name.as_str());
            fmt_str.set_initializer(&format_str);
            fmt_str.set_constant(true);
            fmt_str.set_visibility(GlobalVisibility::Hidden);
            fmt_str
        }
    };
    let format_str_ptr = self.builder.build_pointer_cast(
        format_str_global.as_pointer_value(),
        self.int_ptr_type,
        "format_str_ptr",
    )?;
    let f64_type = self.context.f64_type();
    let value: BasicMetadataValueEnum = match value {
        // Varargs promotion: `%f` expects a double.
        PrintValue::Real(v) if v.get_type() != f64_type => self
            .builder
            .build_float_ext(v, f64_type, "printf_promote")?
            .into(),
        PrintValue::Real(v) => v.into(),
        PrintValue::Int(v) => v.into(),
    };
    self.builder
        .build_call(printf, &[format_str_ptr.into(), value], "printf_call")
        .map_err(|e| anyhow!("Error building call to printf: {}", e))
}
/// Builds `set_constants(thread_id, thread_dim)`, which evaluates every constant tensor
/// definition of `model` into its registered storage.
///
/// Each definition is followed by a call to barrier `k` of `constant_defns().len()`
/// (numbered from 1); `jit_compile_call_barrier` emits nothing when not threaded.
///
/// # Errors
/// Returns an error if tensor compilation fails or the generated function does not verify
/// (the function is then printed, together with the module, and deleted).
fn compile_set_constants(
    &mut self,
    model: &DiscreteModel,
    code: Option<&str>,
) -> Result<FunctionValue<'ctx>> {
    self.clear();
    let param_names = ["thread_id", "thread_dim"];
    let function = self.add_function(
        "set_constants",
        &param_names,
        &[self.int_type.into(), self.int_type.into()],
        None,
        false,
    );
    let _entry = self.start_function(function, code);
    // Spill each parameter to a stack slot so it can be looked up by name.
    for (&param_name, param) in param_names.iter().zip(function.get_param_iter()) {
        let slot = self.function_arg_alloca(param_name, param);
        self.insert_param(param_name, slot);
    }
    self.insert_indices();
    self.insert_constants(model);

    let total_barriers = self
        .int_type
        .const_int(model.constant_defns().len() as u64, false);
    for (barrier_idx, tensor) in (1u64..).zip(model.constant_defns()) {
        let storage = *self.get_var(tensor);
        self.jit_compile_tensor(tensor, Some(storage), code)?;
        let barrier_num = self.int_type.const_int(barrier_idx, false);
        self.jit_compile_call_barrier(barrier_num, total_barriers);
    }
    self.builder.build_return(None)?;

    if !function.verify(true) {
        function.print_to_stderr();
        self.module.print_to_stderr();
        // SAFETY: the function was just created here and nothing else refers to it yet.
        unsafe {
            function.delete();
        }
        return Err(anyhow!("Invalid generated function."));
    }
    Ok(function)
}
/// Builds `barrier_init()`, which stores zero into the shared `thread_counter` global
/// used by the barrier functions.
///
/// # Errors
/// Returns an error if the generated function does not verify (it is then printed and
/// deleted).
fn compile_barrier_init(&mut self) -> Result<FunctionValue<'ctx>> {
    self.clear();
    let function = self.add_function("barrier_init", &[], &[], None, false);
    let _entry = self.start_function(function, None);

    let counter_ptr = self.globals.thread_counter.unwrap().as_pointer_value();
    let zero = self.int_type.const_zero();
    self.builder.build_store(counter_ptr, zero)?;
    self.builder.build_return(None)?;

    if function.verify(true) {
        return Ok(function);
    }
    function.print_to_stderr();
    // SAFETY: the function was just created here and nothing else refers to it yet.
    unsafe {
        function.delete();
    }
    Err(anyhow!("Invalid generated function."))
}
/// Builds `barrier(barrier_num, total_barriers, thread_count)`, a spin barrier on the
/// shared `thread_counter` global. The function is marked `noinline`.
///
/// - If `barrier_num == total_barriers` it returns immediately.
/// - Otherwise the caller atomically increments the counter and spins until the counter
///   reaches `barrier_num * thread_count`, i.e. until `thread_count` threads have passed
///   every barrier up to this one. This presumably relies on `barrier_init` having reset
///   the counter before the threaded run.
///
/// The increment is `acq_rel` and the polling load is `acquire`. This gives every
/// thread's writes before the barrier a happens-before edge to every other thread's
/// reads after it. `monotonic` orderings keep the counter itself consistent but do not
/// order the surrounding (non-atomic) tensor loads and stores.
///
/// # Errors
/// Returns an error if IR construction fails or the generated function does not verify
/// (it is then printed and deleted).
fn compile_barrier(&mut self) -> Result<FunctionValue<'ctx>> {
    self.clear();
    let function = self.add_function(
        "barrier",
        &["barrier_num", "total_barriers", "thread_count"],
        &[
            self.int_type.into(),
            self.int_type.into(),
            self.int_type.into(),
        ],
        None,
        false,
    );
    let noinline_kind_id = Attribute::get_named_enum_kind_id("noinline");
    let noinline = self.context.create_enum_attribute(noinline_kind_id, 0);
    function.add_attribute(AttributeLoc::Function, noinline);
    let _entry_block = self.start_function(function, None);
    let increment_block = self.context.append_basic_block(function, "increment");
    let wait_loop_block = self.context.append_basic_block(function, "wait_loop");
    let barrier_done_block = self.context.append_basic_block(function, "barrier_done");
    let thread_counter = self.globals.thread_counter.unwrap().as_pointer_value();
    let barrier_num = function.get_nth_param(0).unwrap().into_int_value();
    let total_barriers = function.get_nth_param(1).unwrap().into_int_value();
    let thread_count = function.get_nth_param(2).unwrap().into_int_value();

    // entry: barrier number `total_barriers` does not wait at all.
    let nbarrier_equals_total_barriers = self.builder.build_int_compare(
        IntPredicate::EQ,
        barrier_num,
        total_barriers,
        "nbarrier_equals_total_barriers",
    )?;
    self.builder.build_conditional_branch(
        nbarrier_equals_total_barriers,
        barrier_done_block,
        increment_block,
    )?;

    // increment: record this thread's arrival.
    self.builder.position_at_end(increment_block);
    let barrier_num_times_thread_count =
        self.builder
            .build_int_mul(barrier_num, thread_count, "barrier_num_times_thread_count")?;
    // Use the counter's integer type (the type `barrier_init` stores) for both the
    // increment and the load, so they always match the comparison operand above.
    let one = self.int_type.const_int(1, false);
    // Release publishes this thread's earlier writes; acquire observes earlier arrivals.
    self.builder.build_atomicrmw(
        AtomicRMWBinOp::Add,
        thread_counter,
        one,
        AtomicOrdering::AcquireRelease,
    )?;
    self.builder.build_unconditional_branch(wait_loop_block)?;

    // wait_loop: spin until every thread has arrived.
    self.builder.position_at_end(wait_loop_block);
    let current_value = self
        .builder
        .build_load(self.int_type, thread_counter, "current_value")?
        .into_int_value();
    // Acquire pairs with the other threads' release increments, making their writes
    // visible once the target count is observed.
    current_value
        .as_instruction_value()
        .unwrap()
        .set_atomic_ordering(AtomicOrdering::Acquire)
        .map_err(|e| anyhow!("Error setting atomic ordering: {:?}", e))?;
    let all_threads_done = self.builder.build_int_compare(
        IntPredicate::UGE,
        current_value,
        barrier_num_times_thread_count,
        "all_threads_done",
    )?;
    self.builder.build_conditional_branch(
        all_threads_done,
        barrier_done_block,
        wait_loop_block,
    )?;

    self.builder.position_at_end(barrier_done_block);
    self.builder.build_return(None)?;
    if function.verify(true) {
        Ok(function)
    } else {
        function.print_to_stderr();
        // SAFETY: the function was just created here and nothing else refers to it yet.
        unsafe {
            function.delete();
        }
        Err(anyhow!("Invalid generated function."))
    }
}
/// Builds `barrier_grad(barrier_num, total_barriers, thread_count)`, a spin barrier that
/// keeps counting on the same `thread_counter` global as `barrier`.
///
/// The caller atomically increments the counter and spins until it reaches
/// `(2 * total_barriers - barrier_num) * thread_count`.
/// NOTE(review): this appears intended for the reverse pass, visiting barriers in
/// decreasing `barrier_num` after the forward barriers have run — confirm against how the
/// derivative of `barrier` is registered.
///
/// As in `barrier`, the increment is `acq_rel` and the polling load is `acquire`, so
/// writes made before the barrier are visible to every thread after it.
///
/// # Errors
/// Returns an error if IR construction fails or the generated function does not verify
/// (it is then printed and deleted).
fn compile_barrier_grad(&mut self) -> Result<FunctionValue<'ctx>> {
    self.clear();
    let function = self.add_function(
        "barrier_grad",
        &["barrier_num", "total_barriers", "thread_count"],
        &[
            self.int_type.into(),
            self.int_type.into(),
            self.int_type.into(),
        ],
        None,
        false,
    );
    let _entry_block = self.start_function(function, None);
    let wait_loop_block = self.context.append_basic_block(function, "wait_loop");
    let barrier_done_block = self.context.append_basic_block(function, "barrier_done");
    let thread_counter = self.globals.thread_counter.unwrap().as_pointer_value();
    let barrier_num = function.get_nth_param(0).unwrap().into_int_value();
    let total_barriers = function.get_nth_param(1).unwrap().into_int_value();
    let thread_count = function.get_nth_param(2).unwrap().into_int_value();

    // target = (2 * total_barriers - barrier_num) * thread_count
    let twice_total_barriers = self.builder.build_int_mul(
        total_barriers,
        self.int_type.const_int(2, false),
        "twice_total_barriers",
    )?;
    let twice_total_barriers_minus_barrier_num = self.builder.build_int_sub(
        twice_total_barriers,
        barrier_num,
        "twice_total_barriers_minus_barrier_num",
    )?;
    let twice_total_barriers_minus_barrier_num_times_thread_count =
        self.builder.build_int_mul(
            twice_total_barriers_minus_barrier_num,
            thread_count,
            "twice_total_barriers_minus_barrier_num_times_thread_count",
        )?;

    // Record this thread's arrival. The counter's integer type is used so the increment
    // and the load agree with the comparison operand above.
    let one = self.int_type.const_int(1, false);
    self.builder.build_atomicrmw(
        AtomicRMWBinOp::Add,
        thread_counter,
        one,
        AtomicOrdering::AcquireRelease,
    )?;
    self.builder.build_unconditional_branch(wait_loop_block)?;

    // wait_loop: spin until every thread has arrived.
    self.builder.position_at_end(wait_loop_block);
    let current_value = self
        .builder
        .build_load(self.int_type, thread_counter, "current_value")?
        .into_int_value();
    current_value
        .as_instruction_value()
        .unwrap()
        .set_atomic_ordering(AtomicOrdering::Acquire)
        .map_err(|e| anyhow!("Error setting atomic ordering: {:?}", e))?;
    let all_threads_done = self.builder.build_int_compare(
        IntPredicate::UGE,
        current_value,
        twice_total_barriers_minus_barrier_num_times_thread_count,
        "all_threads_done",
    )?;
    self.builder.build_conditional_branch(
        all_threads_done,
        barrier_done_block,
        wait_loop_block,
    )?;

    self.builder.position_at_end(barrier_done_block);
    self.builder.build_return(None)?;
    if function.verify(true) {
        Ok(function)
    } else {
        function.print_to_stderr();
        // SAFETY: the function was just created here and nothing else refers to it yet.
        unsafe {
            function.delete();
        }
        Err(anyhow!("Invalid generated function."))
    }
}
fn jit_compile_call_barrier(
&mut self,
barrier_num: IntValue<'ctx>,
total_barriers: IntValue<'ctx>,
) {
if !self.threaded {
return;
}
let thread_dim = self.get_param("thread_dim");
let thread_dim = self
.builder
.build_load(self.int_type, *thread_dim, "thread_dim")
.unwrap()
.into_int_value();
let barrier = self.get_function("barrier").unwrap();
self.builder
.build_call(
barrier,
&[
BasicMetadataValueEnum::IntValue(barrier_num),
BasicMetadataValueEnum::IntValue(total_barriers),
BasicMetadataValueEnum::IntValue(thread_dim),
],
"barrier",
)
.unwrap();
}
fn jit_threading_limits(
&mut self,
size: IntValue<'ctx>,
) -> Result<(
IntValue<'ctx>,
IntValue<'ctx>,
BasicBlock<'ctx>,
BasicBlock<'ctx>,
)> {
let one = self.int_type.const_int(1, false);
let thread_id = self.get_param("thread_id");
let thread_id = self
.builder
.build_load(self.int_type, *thread_id, "thread_id")
.unwrap()
.into_int_value();
let thread_dim = self.get_param("thread_dim");
let thread_dim = self
.builder
.build_load(self.int_type, *thread_dim, "thread_dim")
.unwrap()
.into_int_value();
// start index is i * size / thread_dim
let i_times_size = self
.builder
.build_int_mul(thread_id, size, "i_times_size")?;
let start = self
.builder
.build_int_unsigned_div(i_times_size, thread_dim, "start")?;
// the ending index for thread i is (i+1) * size / thread_dim
let i_plus_one = self.builder.build_int_add(thread_id, one, "i_plus_one")?;
let i_plus_one_times_size =
self.builder
.build_int_mul(i_plus_one, size, "i_plus_one_times_size")?;
let end = self
.builder
.build_int_unsigned_div(i_plus_one_times_size, thread_dim, "end")?;
let test_done = self.builder.get_insert_block().unwrap();
let next_block = self
.context
.append_basic_block(self.fn_value_opt.unwrap(), "threading_block");
self.builder.position_at_end(next_block);
Ok((start, end, test_done, next_block))
}
/// Closes a per-thread region opened by `jit_threading_limits`.
///
/// The current block jumps to a new `exit` block. `test_done` gets a branch that skips
/// straight to `exit` when `start == end` (this thread has no work) and otherwise enters
/// `next`. The builder is left at the end of `exit`.
fn jit_end_threading(
    &mut self,
    start: IntValue<'ctx>,
    end: IntValue<'ctx>,
    test_done: BasicBlock<'ctx>,
    next: BasicBlock<'ctx>,
) -> Result<()> {
    let current_fn = self.fn_value_opt.unwrap();
    let exit_block = self.context.append_basic_block(current_fn, "exit");
    self.builder.build_unconditional_branch(exit_block)?;

    // Back in the block where the limits were computed: skip empty ranges.
    self.builder.position_at_end(test_done);
    let nothing_to_do = self
        .builder
        .build_int_compare(IntPredicate::EQ, start, end, "done")?;
    self.builder
        .build_conditional_branch(nothing_to_do, exit_block, next)?;

    self.builder.position_at_end(exit_block);
    Ok(())
}
/// Writes the module's LLVM bitcode to `path`.
///
/// LLVM reports failure only as a boolean. Rather than discarding it silently, a failed
/// write is reported on stderr; the signature is unchanged for existing callers.
pub fn write_bitcode_to_path(&self, path: &std::path::Path) {
    if !self.module.write_bitcode_to_path(path) {
        eprintln!("failed to write LLVM bitcode to {}", path.display());
    }
}
/// If a constants global exists, registers it as the `constants` variable and then
/// registers pointers for every constant tensor (and its named blocks) inside it.
/// Does nothing when there is no constants global.
fn insert_constants(&mut self, model: &DiscreteModel) {
    let Some(constants) = self.globals.constants.as_ref() else {
        return;
    };
    let constants_ptr = constants.as_pointer_value();
    self.insert_param("constants", constants_ptr);
    for constant in model.constant_defns() {
        self.insert_tensor(constant, true);
    }
}
/// Registers the model index and the constants, then registers pointers into the `data`
/// buffer for, in order:
/// - the input tensor (if any);
/// - every input-dependent definition;
/// - every time-dependent definition;
/// - every state-dependent definition;
/// - every post-f state-dependent definition.
fn insert_data(&mut self, model: &DiscreteModel) {
    self.insert_model_index();
    self.insert_constants(model);
    let data_tensors = model
        .input()
        .into_iter()
        .chain(model.input_dep_defns())
        .chain(model.time_dep_defns())
        .chain(model.state_dep_defns())
        .chain(model.state_dep_post_f_defns());
    for tensor in data_tensors {
        self.insert_tensor(tensor, false);
    }
}
/// Returns the pointer type used for pointers to values of type `_ty`.
///
/// Pointers are opaque (`Context::ptr_type` in the default address space), so the
/// pointee type is not encoded in the result and `_ty` is unused.
fn pointer_type(context: &'ctx Context, _ty: BasicTypeEnum<'ctx>) -> PointerType<'ctx> {
context.ptr_type(AddressSpace::default())
}
/// Returns the pointer type used for pointers to functions of type `_ty`.
///
/// Pointers are opaque (`Context::ptr_type` in the default address space), so the
/// function type is not encoded in the result and `_ty` is unused.
fn fn_pointer_type(context: &'ctx Context, _ty: FunctionType<'ctx>) -> PointerType<'ctx> {
context.ptr_type(AddressSpace::default())
}
/// If an `indices` global exists, registers a constant pointer to its first `i32`
/// element as the `indices` variable. Does nothing otherwise.
fn insert_indices(&mut self) {
    let Some(indices) = self.globals.indices.as_ref() else {
        return;
    };
    let i32_type = self.context.i32_type();
    let first_element = [i32_type.const_zero()];
    // SAFETY: an in-bounds GEP with index 0 points at the start of the global itself.
    let indices_ptr = unsafe {
        indices
            .as_pointer_value()
            .const_in_bounds_gep(i32_type, &first_element)
    };
    self.variables.insert("indices".to_owned(), indices_ptr);
}
/// Registers the `model_index` global's pointer as the `model_index` variable.
fn insert_model_index(&mut self) {
    let model_index_ptr = self.globals.model_index.as_pointer_value();
    self.insert_param("model_index", model_index_ptr);
}
/// Records `value` as the pointer for variable `name`, replacing any previous entry.
fn insert_param(&mut self, name: &str, value: PointerValue<'ctx>) {
    let key = String::from(name);
    self.variables.insert(key, value);
}
fn build_gep<T: BasicType<'ctx>>(
&self,
ty: T,
ptr: PointerValue<'ctx>,
ordered_indexes: &[IntValue<'ctx>],
name: &str,
) -> Result<PointerValue<'ctx>> {
unsafe {
self.builder
.build_gep(ty, ptr, ordered_indexes, name)
.map_err(|e| e.into())
}
}
/// Emits a load of a value of type `ty` from `ptr`, converting builder errors into
/// `anyhow` errors.
fn build_load<T: BasicType<'ctx>>(
    &self,
    ty: T,
    ptr: PointerValue<'ctx>,
    name: &str,
) -> Result<BasicValueEnum<'ctx>> {
    let loaded = self.builder.build_load(ty, ptr, name)?;
    Ok(loaded)
}
/// Uses `builder` to emit an `inbounds` GEP that points at element `index` (of type `ty`)
/// relative to `ptr`.
///
/// # Panics
/// If the builder fails to emit the instruction.
fn get_ptr_to_index<T: BasicType<'ctx>>(
    builder: &Builder<'ctx>,
    ty: T,
    ptr: &PointerValue<'ctx>,
    index: IntValue<'ctx>,
    name: &str,
) -> PointerValue<'ctx> {
    let gep_indices = [index];
    // SAFETY: callers are trusted to pass an `index` inside the buffer `ptr` points to;
    // an out-of-range `inbounds` GEP would produce a poison pointer.
    let element_ptr = unsafe { builder.build_in_bounds_gep(ty, *ptr, &gep_indices, name) };
    element_ptr.unwrap()
}
/// For each named block of the state tensor `u`, registers (under the block's name) a
/// pointer to the block's first element inside the `u` buffer.
///
/// Blocks are laid out back to back: each one starts `nnz` elements after the previous.
///
/// # Panics
/// If a named block exists but no `u` variable has been registered.
fn insert_state(&mut self, u: &Tensor) {
    let mut offset = 0;
    for blk in u.elmts() {
        if let Some(block_name) = blk.name() {
            let base = *self.variables.get("u").unwrap();
            let element = self
                .context
                .i32_type()
                .const_int(offset.try_into().unwrap(), false);
            let entry_builder = self.create_entry_block_builder();
            let block_ptr = Self::get_ptr_to_index(
                &entry_builder,
                self.real_type,
                &base,
                element,
                block_name,
            );
            self.variables.insert(block_name.to_owned(), block_ptr);
        }
        offset += blk.nnz();
    }
}
/// For each named block of the state-derivative tensor `dudt`, registers (under the
/// block's name) a pointer to the block's first element inside the `dudt` buffer.
///
/// Blocks are laid out back to back: each one starts `nnz` elements after the previous.
///
/// # Panics
/// If a named block exists but no `dudt` variable has been registered.
fn insert_dot_state(&mut self, dudt: &Tensor) {
    let mut offset = 0;
    for blk in dudt.elmts() {
        if let Some(block_name) = blk.name() {
            let base = *self.variables.get("dudt").unwrap();
            let element = self
                .context
                .i32_type()
                .const_int(offset.try_into().unwrap(), false);
            let entry_builder = self.create_entry_block_builder();
            let block_ptr = Self::get_ptr_to_index(
                &entry_builder,
                self.real_type,
                &base,
                element,
                block_name,
            );
            self.variables.insert(block_name.to_owned(), block_ptr);
        }
        offset += blk.nnz();
    }
}
/// Registers pointers for `tensor` and each of its named blocks.
///
/// The pointers go into the `constants` buffer when `is_constant` is true and into the
/// `data` buffer otherwise, starting at the tensor's data index from `self.layout`. All
/// GEPs are emitted in the function's entry block.
///
/// # Panics
/// If the buffer variable is missing or the layout has no data index for the tensor.
fn insert_tensor(&mut self, tensor: &Tensor, is_constant: bool) {
    let buffer_name = if is_constant { "constants" } else { "data" };
    let buffer = *self.variables.get(buffer_name).unwrap();
    let i32_type = self.context.i32_type();
    let mut offset = self.layout.get_data_index(tensor.name()).unwrap();

    let tensor_ptr = Self::get_ptr_to_index(
        &self.create_entry_block_builder(),
        self.real_type,
        &buffer,
        i32_type.const_int(offset.try_into().unwrap(), false),
        tensor.name(),
    );
    self.variables.insert(tensor.name().to_owned(), tensor_ptr);

    // Named blocks are only supported for rank <= 1, so each block's data begins right
    // after the previous block's `nnz` elements.
    for blk in tensor.elmts() {
        if let Some(block_name) = blk.name() {
            let block_ptr = Self::get_ptr_to_index(
                &self.create_entry_block_builder(),
                self.real_type,
                &buffer,
                i32_type.const_int(offset.try_into().unwrap(), false),
                block_name,
            );
            self.variables.insert(block_name.to_owned(), block_ptr);
        }
        offset += blk.nnz();
    }
}
/// Returns the pointer registered for parameter/variable `name`.
///
/// # Panics
/// If nothing has been registered under `name`; the panic message names the variable.
fn get_param(&self, name: &str) -> &PointerValue<'ctx> {
    self.variables
        .get(name)
        .unwrap_or_else(|| panic!("no variable named `{name}` has been inserted"))
}
/// Returns the pointer registered for `tensor`, looked up by the tensor's name.
///
/// # Panics
/// If nothing has been registered under the tensor's name; the panic message names it.
fn get_var(&self, tensor: &Tensor) -> &PointerValue<'ctx> {
    let name = tensor.name();
    self.variables
        .get(name)
        .unwrap_or_else(|| panic!("no variable named `{name}` has been inserted"))
}
/// Returns the function `name` for use from generated code, creating it on first use.
///
/// Lookup order:
/// 1. The `self.functions` cache. This also holds functions registered through
///    `add_function`, such as `barrier`.
/// 2. For `sin`, `cos`, `tan`, `exp`, `log`, `log10`, `sqrt`, `abs`, `copysign`, `pow`,
///    `min` and `max`: the matching LLVM intrinsic (`abs` -> `llvm.fabs`,
///    `min` -> `llvm.minnum`, `max` -> `llvm.maxnum`). If this LLVM does not know the
///    intrinsic, an external libm function with the right name and arity is declared.
/// 3. `sigmoid`, `arcsinh`, `arccosh`, `heaviside`, `tanh`, `sinh` and `cosh`, which are
///    generated here as `real -> real` functions.
///
/// Generating a function moves the builder into it. The builder's insert block, debug
/// location (or its absence) and `fn_value_opt` are restored before returning, including
/// when construction fails part-way.
///
/// Returns `None` for unknown names or if IR construction fails.
///
/// # Panics
/// If there is no current insert block or function, or if a newly built function fails
/// LLVM verification.
fn get_function(&mut self, name: &str) -> Option<FunctionValue<'ctx>> {
    // Check cache for function
    if let Some(&func) = self.functions.get(name) {
        return Some(func);
    }
    let current_block = self.builder.get_insert_block().unwrap();
    // `None` when no debug location is set (e.g. without debug info), so it must not be
    // unwrapped.
    let current_loc = self.builder.get_current_debug_location();
    let current_function = self.fn_value_opt.unwrap();

    let built = self.build_builtin_function(name);

    // Restore the caller's builder state whether or not construction succeeded.
    self.builder.position_at_end(current_block);
    match current_loc {
        Some(loc) => self.builder.set_current_debug_location(loc),
        None => self.builder.unset_current_debug_location(),
    }
    self.fn_value_opt = Some(current_function);

    let function = built?;
    if !function.verify(true) {
        function.print_to_stderr();
        panic!("Invalid generated function for {}", name);
    }
    self.functions.insert(name.to_owned(), function);
    Some(function)
}

/// Creates, without caching or verifying, the built-in function `name` for
/// `get_function`. The builder is left wherever construction ended.
fn build_builtin_function(&mut self, name: &str) -> Option<FunctionValue<'ctx>> {
    match name {
        // support some llvm intrinsics
        "sin" | "cos" | "tan" | "exp" | "log" | "log10" | "sqrt" | "abs" | "copysign"
        | "pow" | "min" | "max" => self.declare_math_function(name),
        "sigmoid" => {
            // sigmoid(x) = 1 / (1 + exp(-x))
            let (fn_val, x) = self.begin_real_unary_function(name)?;
            let one = self.real_type.const_float(1.0);
            let negx = self.builder.build_float_neg(x, name).ok()?;
            let exp_negx = self.call_real_function("exp", negx, name)?;
            let denominator = self.builder.build_float_add(exp_negx, one, name).ok()?;
            let sigmoid = self.builder.build_float_div(one, denominator, name).ok()?;
            self.builder.build_return(Some(&sigmoid)).ok()?;
            Some(fn_val)
        }
        "arcsinh" => {
            // arcsinh(x) = sign(x) * ln(|x| + sqrt(x^2 + 1)). Working on |x| avoids the
            // cancellation in `x + sqrt(x^2 + 1)` for large negative x, which otherwise
            // loses accuracy and reaches ln(0) = -inf for x below about -1e8 (f64).
            let (fn_val, x) = self.begin_real_unary_function(name)?;
            let (is_negative, abs_x) = self.build_sign_split(x, name)?;
            let one = self.real_type.const_float(1.0);
            let abs_x_squared = self.builder.build_float_mul(abs_x, abs_x, name).ok()?;
            let radicand = self.builder.build_float_add(abs_x_squared, one, name).ok()?;
            let root = self.call_real_function("sqrt", radicand, name)?;
            let log_arg = self.builder.build_float_add(abs_x, root, name).ok()?;
            let magnitude = self.call_real_function("log", log_arg, name)?;
            let result = self.build_sign_restore(is_negative, magnitude, name)?;
            self.builder.build_return(Some(&result)).ok()?;
            Some(fn_val)
        }
        "arccosh" => {
            // arccosh(x) = ln(x + sqrt(x^2 - 1))
            let (fn_val, x) = self.begin_real_unary_function(name)?;
            let minus_one = self.real_type.const_float(-1.0);
            let x_squared = self.builder.build_float_mul(x, x, name).ok()?;
            let radicand = self.builder.build_float_add(x_squared, minus_one, name).ok()?;
            let root = self.call_real_function("sqrt", radicand, name)?;
            let log_arg = self.builder.build_float_add(x, root, name).ok()?;
            let result = self.call_real_function("log", log_arg, name)?;
            self.builder.build_return(Some(&result)).ok()?;
            Some(fn_val)
        }
        "heaviside" => {
            // heaviside(x) = 1 if x >= 0 else 0 (ordered compare, so NaN gives 0)
            let (fn_val, x) = self.begin_real_unary_function(name)?;
            let zero = self.real_type.const_float(0.0);
            let one = self.real_type.const_float(1.0);
            let non_negative = self
                .builder
                .build_float_compare(FloatPredicate::OGE, x, zero, "x >= 0")
                .ok()?;
            let result = self
                .builder
                .build_select(non_negative, one, zero, name)
                .ok()?;
            self.builder.build_return(Some(&result)).ok()?;
            Some(fn_val)
        }
        "tanh" => {
            // tanh(x) = sign(x) * (1 - e^{-2|x|}) / (1 + e^{-2|x|}). The exponent is
            // never positive, so unlike (e^x - e^-x) / (e^x + e^-x) this cannot become
            // inf / inf = NaN when |x| exceeds ~709 (f64).
            let (fn_val, x) = self.begin_real_unary_function(name)?;
            let (is_negative, abs_x) = self.build_sign_split(x, name)?;
            let one = self.real_type.const_float(1.0);
            let minus_two = self.real_type.const_float(-2.0);
            let exponent = self.builder.build_float_mul(abs_x, minus_two, name).ok()?;
            let decay = self.call_real_function("exp", exponent, name)?;
            let numerator = self.builder.build_float_sub(one, decay, name).ok()?;
            let denominator = self.builder.build_float_add(one, decay, name).ok()?;
            let magnitude = self
                .builder
                .build_float_div(numerator, denominator, name)
                .ok()?;
            let result = self.build_sign_restore(is_negative, magnitude, name)?;
            self.builder.build_return(Some(&result)).ok()?;
            Some(fn_val)
        }
        "sinh" | "cosh" => {
            // sinh(x) = (e^x - e^-x) / 2, cosh(x) = (e^x + e^-x) / 2
            let (fn_val, x) = self.begin_real_unary_function(name)?;
            let negx = self.builder.build_float_neg(x, name).ok()?;
            let exp_negx = self.call_real_function("exp", negx, name)?;
            let expx = self.call_real_function("exp", x, name)?;
            let combined = if name == "sinh" {
                self.builder.build_float_sub(expx, exp_negx, name).ok()?
            } else {
                self.builder.build_float_add(expx, exp_negx, name).ok()?
            };
            let two = self.real_type.const_float(2.0);
            let result = self.builder.build_float_div(combined, two, name).ok()?;
            self.builder.build_return(Some(&result)).ok()?;
            Some(fn_val)
        }
        _ => None,
    }
}

/// Declares the math function `name` for `get_function`.
///
/// Uses the LLVM intrinsic when this LLVM knows it. Otherwise declares an external libm
/// function with the matching name and arity (`f` suffix for single precision).
fn declare_math_function(&mut self, name: &str) -> Option<FunctionValue<'ctx>> {
    let intrinsic_name = match name {
        "min" => "minnum",
        "max" => "maxnum",
        "abs" => "fabs",
        _ => name,
    };
    let llvm_name = format!("llvm.{}.{}", intrinsic_name, self.diffsl_real_type.as_str());
    let is_f32 = self.diffsl_real_type.as_str() == "f32";
    self.builder.unset_current_debug_location();
    if let Some(intrinsic) = Intrinsic::find(&llvm_name) {
        let overload_types: Vec<BasicTypeEnum> = vec![self.real_type.into()];
        return intrinsic.get_declaration(&self.module, &overload_types);
    }
    // Fallback: external libm function. Binary functions need two parameters, and libm
    // spells abs/min/max as fabs/fmin/fmax.
    let arity = match name {
        "pow" | "copysign" | "min" | "max" => 2,
        _ => 1,
    };
    let libm_base = match name {
        "abs" => "fabs",
        "min" => "fmin",
        "max" => "fmax",
        _ => name,
    };
    let libm_name = if is_f32 {
        format!("{libm_base}f")
    } else {
        libm_base.to_owned()
    };
    let param_types: Vec<BasicMetadataTypeEnum> = vec![self.real_type.into(); arity];
    let fn_type = self.real_type.fn_type(&param_types, false);
    Some(self.module.get_function(&libm_name).unwrap_or_else(|| {
        self.module
            .add_function(&libm_name, fn_type, Some(Linkage::External))
    }))
}

/// Adds a `real -> real` function `name` with parameter `x`, starts its body, and returns
/// the function together with `x`.
fn begin_real_unary_function(
    &mut self,
    name: &str,
) -> Option<(FunctionValue<'ctx>, FloatValue<'ctx>)> {
    let param_types: Vec<BasicMetadataTypeEnum> = vec![self.real_type.into()];
    let fn_val = self.add_function(name, &["x"], &param_types, None, true);
    for param in fn_val.get_param_iter() {
        param.into_float_value().set_name("x");
    }
    let _basic_block = self.start_function(fn_val, None);
    let x = fn_val.get_nth_param(0)?.into_float_value();
    Some((fn_val, x))
}

/// Calls the real-valued unary function `callee` (resolved through `get_function`) on
/// `arg` and returns the result.
fn call_real_function(
    &mut self,
    callee: &str,
    arg: FloatValue<'ctx>,
    name: &str,
) -> Option<FloatValue<'ctx>> {
    let function = self.get_function(callee)?;
    let call = self
        .builder
        .build_call(function, &[BasicMetadataValueEnum::FloatValue(arg)], name)
        .ok()?;
    Some(call.try_as_basic_value().unwrap_basic().into_float_value())
}

/// Emits `(x < 0, |x|)` using only an ordered compare, `fneg` and `select`.
fn build_sign_split(
    &self,
    x: FloatValue<'ctx>,
    name: &str,
) -> Option<(IntValue<'ctx>, FloatValue<'ctx>)> {
    let zero = self.real_type.const_float(0.0);
    let is_negative = self
        .builder
        .build_float_compare(FloatPredicate::OLT, x, zero, name)
        .ok()?;
    let neg_x = self.builder.build_float_neg(x, name).ok()?;
    let abs_x = self
        .builder
        .build_select(is_negative, neg_x, x, name)
        .ok()?
        .into_float_value();
    Some((is_negative, abs_x))
}

/// Emits `-magnitude` when `is_negative` is true and `magnitude` otherwise.
fn build_sign_restore(
    &self,
    is_negative: IntValue<'ctx>,
    magnitude: FloatValue<'ctx>,
    name: &str,
) -> Option<FloatValue<'ctx>> {
    let negated = self.builder.build_float_neg(magnitude, name).ok()?;
    let result = self
        .builder
        .build_select(is_negative, negated, magnitude, name)
        .ok()?;
    Some(result.into_float_value())
}
/// Returns the `FunctionValue` representing the function being compiled.
///
/// # Panics
/// If no function is currently being compiled (`fn_value_opt` is `None`).
#[inline]
fn fn_value(&self) -> FunctionValue<'ctx> {
    self.fn_value_opt
        .expect("no function is currently being compiled")
}
/// Returns the storage pointer of the tensor currently being compiled (as set, for
/// example, by `jit_compile_tensor`).
///
/// # Panics
/// If no tensor storage pointer has been set (`tensor_ptr_opt` is `None`).
#[inline]
fn tensor_ptr(&self) -> PointerValue<'ctx> {
    self.tensor_ptr_opt
        .expect("no tensor storage pointer has been set")
}
/// Creates a new builder at the start of the current function's entry block: before its
/// first instruction, or at its end when the block is still empty. Used for allocas and
/// other entry-block values.
fn create_entry_block_builder(&self) -> Builder<'ctx> {
    let entry_builder = self.context.create_builder();
    let entry_block = self.fn_value().get_first_basic_block().unwrap();
    if let Some(first) = entry_block.get_first_instruction() {
        entry_builder.position_before(&first);
    } else {
        entry_builder.position_at_end(entry_block);
    }
    entry_builder
}
fn jit_compile_scalar(
&mut self,
a: &Tensor,
res_ptr_opt: Option<PointerValue<'ctx>>,
) -> Result<PointerValue<'ctx>> {
let res_type = self.real_type;
let res_ptr = match res_ptr_opt {
Some(ptr) => ptr,
None => self
.create_entry_block_builder()
.build_alloca(res_type, a.name())?,
};
let name = a.name();
let elmt = a.elmts().first().unwrap();
// if threaded then only the first thread will evaluate the scalar
let curr_block = self.builder.get_insert_block().unwrap();
let mut next_block_opt = None;
if self.threaded {
let next_block = self.context.append_basic_block(self.fn_value(), "next");
self.builder.position_at_end(next_block);
next_block_opt = Some(next_block);
}
let zero = self.int_type.const_zero();
let float_value = self.jit_compile_expr(name, elmt.expr(), &[], elmt, zero)?;
self.builder.build_store(res_ptr, float_value)?;
// complete the threading block
if self.threaded {
let exit_block = self.context.append_basic_block(self.fn_value(), "exit");
self.builder.build_unconditional_branch(exit_block)?;
self.builder.position_at_end(curr_block);
let thread_id = self.get_param("thread_id");
let thread_id = self
.builder
.build_load(self.int_type, *thread_id, "thread_id")
.unwrap()
.into_int_value();
let is_first_thread = self.builder.build_int_compare(
IntPredicate::EQ,
thread_id,
self.int_type.const_zero(),
"is_first_thread",
)?;
self.builder.build_conditional_branch(
is_first_thread,
next_block_opt.unwrap(),
exit_block,
)?;
self.builder.position_at_end(exit_block);
}
Ok(res_ptr)
}
/// Compiles every block of tensor `a` into its storage and returns the storage pointer.
///
/// Rank-0 tensors are delegated to `jit_compile_scalar`. Otherwise the storage is
/// `res_ptr_opt`, or a new entry-block alloca when `None`, and it becomes the current
/// tensor pointer. Unnamed blocks are compiled under the name `<tensor>-<block index>`.
/// NOTE(review): the fallback alloca holds a single real; multi-element tensors
/// presumably always receive `res_ptr_opt` — confirm.
fn jit_compile_tensor(
    &mut self,
    a: &Tensor,
    res_ptr_opt: Option<PointerValue<'ctx>>,
    code: Option<&str>,
) -> Result<PointerValue<'ctx>> {
    if a.rank() == 0 {
        return self.jit_compile_scalar(a, res_ptr_opt);
    }
    let res_ptr = if let Some(ptr) = res_ptr_opt {
        ptr
    } else {
        self.create_entry_block_builder()
            .build_alloca(self.real_type, a.name())?
    };
    self.tensor_ptr_opt = Some(res_ptr);
    for (blk_idx, blk) in a.elmts().iter().enumerate() {
        let fallback_name = format!("{}-{}", a.name(), blk_idx);
        let blk_name = blk.name().unwrap_or(fallback_name.as_str());
        self.jit_compile_block(blk_name, a, blk, code)?;
    }
    Ok(res_ptr)
}
/// Compiles block `elmt` of `tensor` into the current tensor storage, dispatching on the
/// block's expression layout: dense, diagonal, sparse, or sparse contraction.
///
/// With debug info enabled, the builder's debug location is first set to the block
/// expression's start in `code`. Line/column 0 is used when the span or source is
/// unavailable, or when the span does not lie within `code`.
///
/// # Errors
/// Returns an error for unsupported layouts or if code generation for the block fails.
fn jit_compile_block(
    &mut self,
    name: &str,
    tensor: &Tensor,
    elmt: &TensorBlock,
    code: Option<&str>,
) -> Result<()> {
    let translation = Translation::new(
        elmt.expr_layout(),
        elmt.layout(),
        elmt.start(),
        tensor.layout_ptr(),
    );
    if let Some(dibuilder) = &self.dibuilder {
        // An out-of-range span would only degrade debug info, so fall back to (0, 0)
        // instead of panicking.
        let (line_no, column_no) = match (elmt.expr().span, code) {
            (Some(s), Some(c)) => Span::new(c, s.pos_start, s.pos_end)
                .map(|span| span.start_pos().line_col())
                .unwrap_or((0, 0)),
            _ => (0, 0),
        };
        let scope = self
            .fn_value_opt
            .unwrap()
            .get_subprogram()
            .unwrap()
            .as_debug_info_scope();
        let loc = dibuilder.create_debug_location(
            self.context,
            line_no.try_into().unwrap(),
            column_no.try_into().unwrap(),
            scope,
            None,
        );
        self.builder.set_current_debug_location(loc);
    }
    if elmt.expr_layout().is_dense() {
        self.jit_compile_dense_block(name, elmt, &translation)
    } else if elmt.expr_layout().is_diagonal() {
        self.jit_compile_diagonal_block(name, elmt, &translation)
    } else if elmt.expr_layout().is_sparse() {
        match translation.source {
            TranslationFrom::SparseContraction { .. } => {
                self.jit_compile_sparse_contraction_block(name, elmt, &translation)
            }
            _ => self.jit_compile_sparse_block(name, elmt, &translation),
        }
    } else {
        Err(anyhow!(
            "unsupported block layout: {:?}",
            elmt.expr_layout()
        ))
    }
}
// for dense blocks we can loop through the nested loops to calculate the index, then we compile the expression passing in this index
fn jit_compile_dense_block(
&mut self,
name: &str,
elmt: &TensorBlock,
translation: &Translation,
) -> Result<()> {
let int_type = self.int_type;
let mut preblock = self.builder.get_insert_block().unwrap();
let expr_rank = elmt.expr_layout().rank();
let expr_shape = elmt
.expr_layout()
.shape()
.mapv(|n| int_type.const_int(n.try_into().unwrap(), false));
let one = int_type.const_int(1, false);
let mut expr_strides = vec![1; expr_rank];
if expr_rank > 0 {
for i in (0..expr_rank - 1).rev() {
expr_strides[i] = expr_strides[i + 1] * elmt.expr_layout().shape()[i + 1];
}
}
let expr_strides = expr_strides
.iter()
.map(|&s| int_type.const_int(s.try_into().unwrap(), false))
.collect::<Vec<IntValue>>();
// setup indices, loop through the nested loops
let mut indices = Vec::new();
let mut blocks = Vec::new();
// allocate the contract sum if needed
let (contract_sum, contract_by, contract_strides) =
if let TranslationFrom::DenseContraction {
contract_by,
contract_len: _,
} = translation.source
{
let contract_rank = expr_rank - contract_by;
let mut contract_strides = vec![1; contract_rank];
for i in (0..contract_rank - 1).rev() {
contract_strides[i] =
contract_strides[i + 1] * elmt.expr_layout().shape()[i + 1];
}
let contract_strides = contract_strides
.iter()
.map(|&s| int_type.const_int(s.try_into().unwrap(), false))
.collect::<Vec<IntValue>>();
(
Some(self.builder.build_alloca(self.real_type, "contract_sum")?),
contract_by,
Some(contract_strides),
)
} else {
(None, 0, None)
};
// we will thread the output loop, except if we are contracting to a scalar
let (thread_start, thread_end, test_done, next) = if self.threaded {
let (start, end, test_done, next) =
self.jit_threading_limits(*expr_shape.get(0).unwrap_or(&one))?;
preblock = next;
(Some(start), Some(end), Some(test_done), Some(next))
} else {
(None, None, None, None)
};
for i in 0..expr_rank {
let block = self.context.append_basic_block(self.fn_value(), name);
self.builder.build_unconditional_branch(block)?;
self.builder.position_at_end(block);
let start_index = if i == 0 && self.threaded {
thread_start.unwrap()
} else {
self.int_type.const_zero()
};
let curr_index = self.builder.build_phi(int_type, format!["i{i}"].as_str())?;
curr_index.add_incoming(&[(&start_index, preblock)]);
if i == expr_rank - contract_by - 1 {
if let Some(contract_sum) = contract_sum {
self.builder
.build_store(contract_sum, self.real_type.const_zero())?;
}
}
indices.push(curr_index);
blocks.push(block);
preblock = block;
}
let indices_int: Vec<IntValue> = indices
.iter()
.map(|i| i.as_basic_value().into_int_value())
.collect();
// if indices = (i, j, k) and shape = (a, b, c) calculate expr_index = (k + j*b + i*b*c)
let mut expr_index = *indices_int.last().unwrap_or(&int_type.const_zero());
let mut stride = 1u64;
if !indices.is_empty() {
for i in (0..indices.len() - 1).rev() {
let iname_i = indices_int[i];
let shapei: u64 = elmt.expr_layout().shape()[i + 1].try_into().unwrap();
stride *= shapei;
let stride_intval = self.context.i32_type().const_int(stride, false);
let stride_mul_i = self.builder.build_int_mul(stride_intval, iname_i, name)?;
expr_index = self.builder.build_int_add(expr_index, stride_mul_i, name)?;
}
}
let float_value =
self.jit_compile_expr(name, elmt.expr(), indices_int.as_slice(), elmt, expr_index)?;
if let Some(contract_sum) = contract_sum {
let contract_sum_value = self
.build_load(self.real_type, contract_sum, "contract_sum")?
.into_float_value();
let new_contract_sum_value = self.builder.build_float_add(
contract_sum_value,
float_value,
"new_contract_sum",
)?;
self.builder
.build_store(contract_sum, new_contract_sum_value)?;
} else {
let expr_index = indices_int.iter().zip(expr_strides.iter()).fold(
self.int_type.const_zero(),
|acc, (i, s)| {
let tmp = self.builder.build_int_mul(*i, *s, "expr_index").unwrap();
self.builder.build_int_add(acc, tmp, "acc").unwrap()
},
);
self.jit_compile_broadcast_and_store(name, elmt, float_value, expr_index, translation)?;
}
let mut postblock = self.builder.get_insert_block().unwrap();
// unwind the nested loops
for i in (0..expr_rank).rev() {
// increment index
let next_index = self.builder.build_int_add(indices_int[i], one, name)?;
indices[i].add_incoming(&[(&next_index, postblock)]);
if i == expr_rank - contract_by - 1 {
if let Some(contract_sum) = contract_sum {
let contract_sum_value = self
.build_load(self.real_type, contract_sum, "contract_sum")?
.into_float_value();
let contract_strides = contract_strides.as_ref().unwrap();
let elmt_index = indices_int
.iter()
.take(contract_strides.len())
.zip(contract_strides.iter())
.fold(self.int_type.const_zero(), |acc, (i, s)| {
let tmp = self.builder.build_int_mul(*i, *s, "elmt_index").unwrap();
self.builder.build_int_add(acc, tmp, "acc").unwrap()
});
self.jit_compile_store(
name,
elmt,
elmt_index,
contract_sum_value,
translation,
)?;
}
}
let end_index = if i == 0 && self.threaded {
thread_end.unwrap()
} else {
expr_shape[i]
};
// loop condition
let loop_while =
self.builder
.build_int_compare(IntPredicate::ULT, next_index, end_index, name)?;
let block = self.context.append_basic_block(self.fn_value(), name);
self.builder
.build_conditional_branch(loop_while, blocks[i], block)?;
self.builder.position_at_end(block);
postblock = block;
}
if self.threaded {
self.jit_end_threading(
thread_start.unwrap(),
thread_end.unwrap(),
test_done.unwrap(),
next.unwrap(),
)?;
}
Ok(())
}
fn jit_compile_sparse_contraction_block(
&mut self,
name: &str,
elmt: &TensorBlock,
translation: &Translation,
) -> Result<()> {
match translation.source {
TranslationFrom::SparseContraction { .. } => {}
_ => {
panic!("expected sparse contraction")
}
}
let int_type = self.int_type;
let translation_index = self
.layout
.get_translation_index(elmt.expr_layout(), elmt.layout())
.unwrap();
let translation_index = translation_index + translation.get_from_index_in_data_layout();
let final_contract_index =
int_type.const_int(elmt.layout().nnz().try_into().unwrap(), false);
let (thread_start, thread_end, test_done, next) = if self.threaded {
let (start, end, test_done, next) = self.jit_threading_limits(final_contract_index)?;
(Some(start), Some(end), Some(test_done), Some(next))
} else {
(None, None, None, None)
};
let preblock = self.builder.get_insert_block().unwrap();
let contract_sum_ptr = self.builder.build_alloca(self.real_type, "contract_sum")?;
// loop through each contraction
let block = self.context.append_basic_block(self.fn_value(), name);
self.builder.build_unconditional_branch(block)?;
self.builder.position_at_end(block);
let contract_index = self.builder.build_phi(int_type, "i")?;
let contract_start = if self.threaded {
thread_start.unwrap()
} else {
int_type.const_zero()
};
contract_index.add_incoming(&[(&contract_start, preblock)]);
let start_index = self.builder.build_int_add(
int_type.const_int(translation_index.try_into().unwrap(), false),
self.builder.build_int_mul(
int_type.const_int(2, false),
contract_index.as_basic_value().into_int_value(),
name,
)?,
name,
)?;
let end_index =
self.builder
.build_int_add(start_index, int_type.const_int(1, false), name)?;
let start_ptr = self.build_gep(
self.int_type,
*self.get_param("indices"),
&[start_index],
"start_index_ptr",
)?;
let start_contract = self
.build_load(self.int_type, start_ptr, "start")?
.into_int_value();
let end_ptr = self.build_gep(
self.int_type,
*self.get_param("indices"),
&[end_index],
"end_index_ptr",
)?;
let end_contract = self
.build_load(self.int_type, end_ptr, "end")?
.into_int_value();
// initialise the contract sum
self.builder
.build_store(contract_sum_ptr, self.real_type.const_float(0.0))?;
// loop through each element in the contraction
let start_contract_block = self
.context
.append_basic_block(self.fn_value(), format!("{name}_contract").as_str());
self.builder
.build_unconditional_branch(start_contract_block)?;
self.builder.position_at_end(start_contract_block);
let expr_index_phi = self.builder.build_phi(int_type, "j")?;
expr_index_phi.add_incoming(&[(&start_contract, block)]);
let expr_index = expr_index_phi.as_basic_value().into_int_value();
let indices_int = self.expr_indices_from_elmt_index(expr_index, elmt, name)?;
// loop body - eval expression and increment sum
let float_value =
self.jit_compile_expr(name, elmt.expr(), indices_int.as_slice(), elmt, expr_index)?;
let contract_sum_value = self
.build_load(self.real_type, contract_sum_ptr, "contract_sum")?
.into_float_value();
let new_contract_sum_value =
self.builder
.build_float_add(contract_sum_value, float_value, "new_contract_sum")?;
self.builder
.build_store(contract_sum_ptr, new_contract_sum_value)?;
let end_contract_block = self.builder.get_insert_block().unwrap();
// increment contract loop index
let next_elmt_index =
self.builder
.build_int_add(expr_index, int_type.const_int(1, false), name)?;
expr_index_phi.add_incoming(&[(&next_elmt_index, end_contract_block)]);
// contract loop condition
let loop_while = self.builder.build_int_compare(
IntPredicate::ULT,
next_elmt_index,
end_contract,
name,
)?;
let post_contract_block = self.context.append_basic_block(self.fn_value(), name);
self.builder.build_conditional_branch(
loop_while,
start_contract_block,
post_contract_block,
)?;
self.builder.position_at_end(post_contract_block);
// store the result
self.jit_compile_store(
name,
elmt,
contract_index.as_basic_value().into_int_value(),
new_contract_sum_value,
translation,
)?;
// increment outer loop index
let next_contract_index = self.builder.build_int_add(
contract_index.as_basic_value().into_int_value(),
int_type.const_int(1, false),
name,
)?;
contract_index.add_incoming(&[(&next_contract_index, post_contract_block)]);
// outer loop condition
let loop_while = self.builder.build_int_compare(
IntPredicate::ULT,
next_contract_index,
thread_end.unwrap_or(final_contract_index),
name,
)?;
let post_block = self.context.append_basic_block(self.fn_value(), name);
self.builder
.build_conditional_branch(loop_while, block, post_block)?;
self.builder.position_at_end(post_block);
if self.threaded {
self.jit_end_threading(
thread_start.unwrap(),
thread_end.unwrap(),
test_done.unwrap(),
next.unwrap(),
)?;
}
Ok(())
}
/// Emits loads of the per-dimension indices of expression element `elmt_index`.
///
/// The expression layout stores its index table in the `indices` array starting at
/// its layout index, with `rank` entries per element. This reads the entries at
/// `layout_index + elmt_index * rank + d` for each dimension `d`.
fn expr_indices_from_elmt_index(
    &mut self,
    elmt_index: IntValue<'ctx>,
    elmt: &TensorBlock,
    name: &str,
) -> Result<Vec<IntValue<'ctx>>, anyhow::Error> {
    let layout_base = self.layout.get_layout_index(elmt.expr_layout()).unwrap();
    let rank = elmt.expr_layout().rank();
    let rank_const = self.int_type.const_int(rank.try_into().unwrap(), false);
    let row_offset = self.builder.build_int_mul(elmt_index, rank_const, name)?;
    let mut loaded = Vec::new();
    for dim in 0..rank {
        let column = self
            .int_type
            .const_int((layout_base + dim).try_into().unwrap(), false);
        let position = self.builder.build_int_add(row_offset, column, name)?;
        let slot = Self::get_ptr_to_index(
            &self.builder,
            self.int_type,
            self.get_param("indices"),
            position,
            name,
        );
        let value = self.build_load(self.int_type, slot, name)?.into_int_value();
        loaded.push(value);
    }
    Ok(loaded)
}
// Sparse blocks: loop over the non-zeros of the expression layout. For each
// non-zero, load its per-dimension indices from the layout, compile the expression
// with them, and broadcast/store the result.
// TODO: contractions are not handled here yet.
fn jit_compile_sparse_block(
    &mut self,
    name: &str,
    elmt: &TensorBlock,
    translation: &Translation,
) -> Result<()> {
    let int_type = self.int_type;
    let nnz = int_type.const_int(elmt.expr_layout().nnz().try_into().unwrap(), false);
    let limits = if self.threaded {
        Some(self.jit_threading_limits(nnz)?)
    } else {
        None
    };
    let loop_start = limits
        .as_ref()
        .map_or(int_type.const_int(0, false), |l| l.0);
    let loop_end = limits.as_ref().map_or(nnz, |l| l.1);

    let entry_block = self.builder.get_insert_block().unwrap();
    let loop_block = self.context.append_basic_block(self.fn_value(), name);
    self.builder.build_unconditional_branch(loop_block)?;
    self.builder.position_at_end(loop_block);
    let counter = self.builder.build_phi(int_type, "i")?;
    counter.add_incoming(&[(&loop_start, entry_block)]);
    let elmt_index = counter.as_basic_value().into_int_value();

    // body: look up this non-zero's indices, evaluate, store
    let indices = self.expr_indices_from_elmt_index(elmt_index, elmt, name)?;
    let value = self.jit_compile_expr(name, elmt.expr(), &indices, elmt, elmt_index)?;
    self.jit_compile_broadcast_and_store(name, elmt, value, elmt_index, translation)?;

    // the body may have appended blocks; the back-edge leaves from wherever we are now
    let latch_block = self.builder.get_insert_block().unwrap();
    let next_index = self
        .builder
        .build_int_add(elmt_index, int_type.const_int(1, false), name)?;
    counter.add_incoming(&[(&next_index, latch_block)]);
    let keep_going =
        self.builder
            .build_int_compare(IntPredicate::ULT, next_index, loop_end, name)?;
    let exit_block = self.context.append_basic_block(self.fn_value(), name);
    self.builder
        .build_conditional_branch(keep_going, loop_block, exit_block)?;
    self.builder.position_at_end(exit_block);

    if let Some((start, end, test_done, next)) = limits {
        self.jit_end_threading(start, end, test_done, next)?;
    }
    Ok(())
}
// Diagonal blocks: loop over the diagonal elements. Every dimension's index equals
// the element index, so no layout lookup is needed before compiling the expression.
fn jit_compile_diagonal_block(
    &mut self,
    name: &str,
    elmt: &TensorBlock,
    translation: &Translation,
) -> Result<()> {
    let int_type = self.int_type;
    let nnz = int_type.const_int(elmt.expr_layout().nnz().try_into().unwrap(), false);
    let threading = if self.threaded {
        Some(self.jit_threading_limits(nnz)?)
    } else {
        None
    };
    let (first_index, stop_index) = match &threading {
        Some((start, end, _, _)) => (*start, *end),
        None => (int_type.const_int(0, false), nnz),
    };

    let entry_block = self.builder.get_insert_block().unwrap();
    let body_block = self.context.append_basic_block(self.fn_value(), name);
    self.builder.build_unconditional_branch(body_block)?;
    self.builder.position_at_end(body_block);
    let index_phi = self.builder.build_phi(int_type, "i")?;
    index_phi.add_incoming(&[(&first_index, entry_block)]);
    let elmt_index = index_phi.as_basic_value().into_int_value();

    // on the diagonal every dimension shares the same index
    let indices: Vec<IntValue> = (0..elmt.expr_layout().rank())
        .map(|_| elmt_index)
        .collect();
    let value = self.jit_compile_expr(name, elmt.expr(), &indices, elmt, elmt_index)?;
    self.jit_compile_broadcast_and_store(name, elmt, value, elmt_index, translation)?;

    let latch_block = self.builder.get_insert_block().unwrap();
    let incremented = self
        .builder
        .build_int_add(elmt_index, int_type.const_int(1, false), name)?;
    index_phi.add_incoming(&[(&incremented, latch_block)]);
    let in_range =
        self.builder
            .build_int_compare(IntPredicate::ULT, incremented, stop_index, name)?;
    let done_block = self.context.append_basic_block(self.fn_value(), name);
    self.builder
        .build_conditional_branch(in_range, body_block, done_block)?;
    self.builder.position_at_end(done_block);

    if let Some((start, end, test_done, next)) = threading {
        self.jit_end_threading(start, end, test_done, next)?;
    }
    Ok(())
}
/// Stores `float_value` for expression element `expr_index` according to
/// `translation.source`.
///
/// * Element-wise and diagonal-contraction translations store it once at `expr_index`.
/// * A broadcast of length `broadcast_len` emits a loop that stores the same value at
///   `expr_index * broadcast_len + k` for `k` in `0..broadcast_len`.
///
/// Any other translation source yields an "Invalid translation" error.
fn jit_compile_broadcast_and_store(
    &mut self,
    name: &str,
    elmt: &TensorBlock,
    float_value: FloatValue<'ctx>,
    expr_index: IntValue<'ctx>,
    translation: &Translation,
) -> Result<()> {
    let int_type = self.int_type;
    let entry_block = self.builder.get_insert_block().unwrap();
    let broadcast_len = match translation.source {
        TranslationFrom::ElementWise | TranslationFrom::DiagonalContraction { .. } => {
            return self.jit_compile_store(name, elmt, expr_index, float_value, translation);
        }
        TranslationFrom::Broadcast { broadcast_len, .. } => broadcast_len,
        _ => return Err(anyhow!("Invalid translation")),
    };

    let len = int_type.const_int(broadcast_len.try_into().unwrap(), false);
    let zero = int_type.const_int(0, false);
    let loop_block = self.context.append_basic_block(self.fn_value(), name);
    self.builder.build_unconditional_branch(loop_block)?;
    self.builder.position_at_end(loop_block);
    let phi = self.builder.build_phi(int_type, "broadcast_index")?;
    phi.add_incoming(&[(&zero, entry_block)]);
    let bcast_index = phi.as_basic_value().into_int_value();

    // destination = expr_index * broadcast_len + bcast_index
    let scaled = self
        .builder
        .build_int_mul(expr_index, len, "store_index")?;
    let store_index = self
        .builder
        .build_int_add(scaled, bcast_index, "bcast_store_index")?;
    self.jit_compile_store(name, elmt, store_index, float_value, translation)?;

    let bumped = self
        .builder
        .build_int_add(bcast_index, int_type.const_int(1, false), name)?;
    phi.add_incoming(&[(&bumped, loop_block)]);
    let more =
        self.builder
            .build_int_compare(IntPredicate::ULT, bumped, len, "broadcast_cond")?;
    let after_block = self.context.append_basic_block(self.fn_value(), name);
    self.builder
        .build_conditional_branch(more, loop_block, after_block)?;
    self.builder.position_at_end(after_block);
    Ok(())
}
/// Writes `float_value` into the current tensor at a position derived from
/// `store_index` and `translation.target`.
///
/// * `Contiguous { start, .. }`: the position is `start + store_index`.
/// * `Sparse { .. }`: the position is read from the `indices` array at the block's
///   translation index plus the target offset plus `store_index`.
fn jit_compile_store(
    &mut self,
    name: &str,
    elmt: &TensorBlock,
    store_index: IntValue<'ctx>,
    float_value: FloatValue<'ctx>,
    translation: &Translation,
) -> Result<()> {
    let int_type = self.int_type;
    let res_index = match &translation.target {
        TranslationTo::Contiguous { start, .. } => {
            let offset = int_type.const_int((*start).try_into().unwrap(), false);
            self.builder.build_int_add(offset, store_index, name)?
        }
        TranslationTo::Sparse { .. } => {
            // the destination is itself looked up in `indices`
            let base = self
                .layout
                .get_translation_index(elmt.expr_layout(), elmt.layout())
                .unwrap()
                + translation.get_to_index_in_data_layout();
            let base = int_type.const_int(base.try_into().unwrap(), false);
            let lookup = self.builder.build_int_add(store_index, base, name)?;
            let lookup_ptr = Self::get_ptr_to_index(
                &self.builder,
                int_type,
                self.get_param("indices"),
                lookup,
                name,
            );
            self.build_load(int_type, lookup_ptr, name)?.into_int_value()
        }
    };
    let target_ptr = Self::get_ptr_to_index(
        &self.builder,
        self.real_type,
        &self.tensor_ptr(),
        res_index,
        name,
    );
    self.builder.build_store(target_ptr, float_value)?;
    Ok(())
}
/// Recursively compiles the scalar expression `expr` for one element of `elmt` and
/// returns the resulting real value.
///
/// * `name` - base name for generated IR values; replaced by `elmt.name()` when the
///   block has a name.
/// * `index` - per-dimension index values of the current element, ordered like
///   `elmt.indices()`.
/// * `expr_index` - flat index of the current element in the expression layout. It is
///   used directly for sparse/diagonal operands whose layout equals the expression
///   layout, or to look up a mapped index in a binary layout otherwise.
///
/// # Errors
/// Returns an error for unknown operators or functions, invalid index expressions,
/// unsupported layouts or expression kinds, and failed IR builder calls.
fn jit_compile_expr(
    &mut self,
    name: &str,
    expr: &Ast,
    index: &[IntValue<'ctx>],
    elmt: &TensorBlock,
    expr_index: IntValue<'ctx>,
) -> Result<FloatValue<'ctx>> {
    let name = elmt.name().unwrap_or(name);
    match &expr.kind {
        AstKind::Binop(binop) => {
            let lhs =
                self.jit_compile_expr(name, binop.left.as_ref(), index, elmt, expr_index)?;
            let rhs =
                self.jit_compile_expr(name, binop.right.as_ref(), index, elmt, expr_index)?;
            match binop.op {
                '*' => Ok(self.builder.build_float_mul(lhs, rhs, name)?),
                '/' => Ok(self.builder.build_float_div(lhs, rhs, name)?),
                '-' => Ok(self.builder.build_float_sub(lhs, rhs, name)?),
                '+' => Ok(self.builder.build_float_add(lhs, rhs, name)?),
                unknown => Err(anyhow!("unknown binop op '{}'", unknown)),
            }
        }
        AstKind::Monop(monop) => {
            let child =
                self.jit_compile_expr(name, monop.child.as_ref(), index, elmt, expr_index)?;
            match monop.op {
                '-' => Ok(self.builder.build_float_neg(child, name)?),
                unknown => Err(anyhow!("unknown monop op '{}'", unknown)),
            }
        }
        AstKind::Call(call) => match self.get_function(call.fn_name) {
            Some(function) => {
                let mut args: Vec<BasicMetadataValueEnum> = Vec::new();
                for arg in call.args.iter() {
                    let arg_val =
                        self.jit_compile_expr(name, arg.as_ref(), index, elmt, expr_index)?;
                    args.push(BasicMetadataValueEnum::FloatValue(arg_val));
                }
                let ret_value = self
                    .builder
                    .build_call(function, args.as_slice(), name)?
                    .try_as_basic_value()
                    .unwrap_basic()
                    .into_float_value();
                Ok(ret_value)
            }
            None => Err(anyhow!("unknown function call '{}'", call.fn_name)),
        },
        AstKind::CallArg(arg) => {
            self.jit_compile_expr(name, &arg.expression, index, elmt, expr_index)
        }
        AstKind::Number(value) => Ok(self.real_type.const_float(*value)),
        AstKind::Name(iname) => {
            if iname.name == "N" {
                // the tangent of `N` is always zero
                if iname.is_tangent {
                    return Ok(self.real_type.const_float(0.0));
                }
                let model_index = self
                    .build_load(self.int_type, *self.get_param("model_index"), "model_index")?
                    .into_int_value();
                let n_value = self.builder.build_signed_int_to_float(
                    model_index,
                    self.real_type,
                    "n_as_real",
                )?;
                return Ok(n_value);
            }
            let ptr = *self.get_param(iname.name);
            let layout = self.layout.get_layout(iname.name).unwrap().clone();
            let iname_elmt_index = if layout.is_dense() {
                // Every index constant uses `int_type` so it has the same LLVM type as the
                // entries of `index`. Those entries are assumed to be `int_type`, as they
                // are in the sparse and diagonal loops; a hard-coded i32 would give
                // mismatched operand types whenever `int_type` is not i32.
                let zero = self.int_type.const_int(0, false);
                // permute indices based on the index chars of this tensor
                let mut iname_index = Vec::with_capacity(iname.indices.len());
                for c in iname.indices.iter() {
                    // Find the position of this index char in the block's index chars.
                    // A char that is not found is a contraction index, which is placed
                    // at the end.
                    let pi = elmt
                        .indices()
                        .iter()
                        .position(|x| x == c)
                        .unwrap_or(elmt.indices().len());
                    // If we are indexing a single element, the position may be out of
                    // bounds; use 0 in that case.
                    let index_pi = index.get(pi).copied().unwrap_or(zero);
                    if let Some(indice_ast) = iname.indice.as_ref() {
                        // offset by the start (`first`) of the index expression
                        let Some(indice) = indice_ast.kind.as_indice() else {
                            return Err(anyhow!("invalid index expression '{}'", indice_ast));
                        };
                        let start_intval =
                            self.jit_compile_integer_expr(indice.first.as_ref(), name)?;
                        let offset_index =
                            self.builder.build_int_add(index_pi, start_intval, name)?;
                        iname_index.push(offset_index);
                    } else {
                        iname_index.push(index_pi);
                    }
                }
                // broadcasting: along a dimension of size 1 the index is always 0
                let iname_index: Vec<IntValue> = iname_index
                    .iter()
                    .enumerate()
                    .map(|(i, &idx)| if layout.shape()[i] == 1 { zero } else { idx })
                    .collect();
                // TODO: can we optimise this by using expr_index, and also including elmt_index?
                if let Some(&last) = iname_index.last() {
                    // Row-major flattening: for indices (a, b, c) and shape s this computes
                    // c + b * s[2] + a * s[2] * s[1].
                    let mut iname_elmt_index = last;
                    let mut stride = 1u64;
                    for i in (0..iname_index.len() - 1).rev() {
                        let shapei: u64 = layout.shape()[i + 1].try_into().unwrap();
                        stride *= shapei;
                        let stride_intval = self.int_type.const_int(stride, false);
                        let stride_mul_i =
                            self.builder
                                .build_int_mul(stride_intval, iname_index[i], name)?;
                        iname_elmt_index =
                            self.builder
                                .build_int_add(iname_elmt_index, stride_mul_i, name)?;
                    }
                    iname_elmt_index
                } else {
                    // no index characters: the only element is at index zero
                    zero
                }
            } else if layout.is_sparse() || layout.is_diagonal() {
                let expr_layout = elmt.expr_layout();
                if expr_layout != &layout {
                    // Get the operand's index from the binary layout map:
                    //   mapped = indices[binary_layout_index + expr_index]
                    // A negative mapped index means the operand has no entry here, so
                    // the value is 0. Otherwise, load the value at the mapped index.
                    // This needs a branch, so the phi result is returned early.
                    let permutation =
                        DataLayout::permutation(elmt, iname.indices.as_slice(), &layout);
                    if let Some(base_binary_layout_index) =
                        self.layout
                            .get_binary_layout_index(&layout, expr_layout, permutation)
                    {
                        let binary_layout_index = self.builder.build_int_add(
                            self.int_type
                                .const_int(base_binary_layout_index.try_into().unwrap(), false),
                            expr_index,
                            name,
                        )?;
                        let indices_ptr = Self::get_ptr_to_index(
                            &self.builder,
                            self.int_type,
                            self.get_param("indices"),
                            binary_layout_index,
                            name,
                        );
                        let mapped_index = self
                            .build_load(self.int_type, indices_ptr, name)?
                            .into_int_value();
                        let is_less_than_zero = self.builder.build_int_compare(
                            IntPredicate::SLT,
                            mapped_index,
                            self.int_type.const_int(0, true),
                            "sparse_index_check",
                        )?;
                        let is_less_than_zero_block =
                            self.context.append_basic_block(self.fn_value(), "lt_zero");
                        let not_less_than_zero_block = self
                            .context
                            .append_basic_block(self.fn_value(), "not_lt_zero");
                        let merge_block =
                            self.context.append_basic_block(self.fn_value(), "merge");
                        self.builder.build_conditional_branch(
                            is_less_than_zero,
                            is_less_than_zero_block,
                            not_less_than_zero_block,
                        )?;
                        // mapped index < 0: the value is 0
                        self.builder.position_at_end(is_less_than_zero_block);
                        let zero_value = self.real_type.const_float(0.);
                        self.builder.build_unconditional_branch(merge_block)?;
                        // mapped index >= 0: load the value at that index
                        self.builder.position_at_end(not_less_than_zero_block);
                        let value_ptr = Self::get_ptr_to_index(
                            &self.builder,
                            self.real_type,
                            &ptr,
                            mapped_index,
                            name,
                        );
                        let value = self
                            .build_load(self.real_type, value_ptr, name)?
                            .into_float_value();
                        self.builder.build_unconditional_branch(merge_block)?;
                        // merge the two branches
                        self.builder.position_at_end(merge_block);
                        let if_return_value =
                            self.builder.build_phi(self.real_type, "sparse_value")?;
                        if_return_value.add_incoming(&[(&zero_value, is_less_than_zero_block)]);
                        if_return_value.add_incoming(&[(&value, not_less_than_zero_block)]);
                        let phi_value = if_return_value.as_basic_value().into_float_value();
                        return Ok(phi_value);
                    } else {
                        expr_index
                    }
                } else {
                    // the layouts are the same, so the expression index addresses this operand
                    expr_index
                }
            } else {
                return Err(anyhow!("unexpected layout"));
            };
            let value_ptr = Self::get_ptr_to_index(
                &self.builder,
                self.real_type,
                &ptr,
                iname_elmt_index,
                name,
            );
            Ok(self
                .build_load(self.real_type, value_ptr, name)?
                .into_float_value())
        }
        AstKind::NamedGradient(gradient) => {
            let name_str = gradient.to_string();
            let ptr = self.get_param(name_str.as_str());
            Ok(self
                .build_load(self.real_type, *ptr, name_str.as_str())?
                .into_float_value())
        }
        // Index, Slice, Integer and any other kind are not valid scalar expressions here
        _ => Err(anyhow!(
            "unsupported expression '{}' in tensor expression",
            expr
        )),
    }
}
/// Compiles an integer-valued expression (as used in index offsets) to an `int_type`
/// value.
///
/// Supports:
/// * integer literals, and real literals with no fractional part;
/// * the model index `N`, loaded from the `model_index` parameter;
/// * unary `+`/`-`;
/// * binary `+ - * / %`, with signed division and remainder.
///
/// # Errors
/// Returns an error for non-integral real literals, names other than `N`, unknown
/// operators, any other expression kind, or failed IR builder calls.
fn jit_compile_integer_expr(&mut self, expr: &Ast, name: &str) -> Result<IntValue<'ctx>> {
    match &expr.kind {
        AstKind::Integer(value) => Ok(self.int_type.const_int(*value as u64, true)),
        AstKind::Number(value) => {
            // `fract()` is NaN for infinities/NaN, so those are rejected here as well
            if value.fract() != 0.0 {
                return Err(anyhow!(
                    "non-integer value '{}' in integer expression",
                    value
                ));
            }
            // Cast through i64. A direct `f64 as u64` cast saturates negative values
            // to 0 (e.g. -2.0 would become 0); going via i64 keeps the two's-complement
            // bit pattern that `const_int(.., true)` expects.
            Ok(self.int_type.const_int(*value as i64 as u64, true))
        }
        AstKind::Name(iname) => {
            if iname.name == "N" {
                Ok(self
                    .build_load(self.int_type, *self.get_param("model_index"), name)?
                    .into_int_value())
            } else {
                Err(anyhow!(
                    "unsupported name '{}' in integer expression",
                    iname.name
                ))
            }
        }
        AstKind::Monop(monop) => {
            let child = self.jit_compile_integer_expr(monop.child.as_ref(), name)?;
            match monop.op {
                '+' => Ok(child),
                '-' => self.builder.build_int_neg(child, name).map_err(Into::into),
                _ => Err(anyhow!("unknown integer unary op '{}'", monop.op)),
            }
        }
        AstKind::Binop(binop) => {
            let lhs = self.jit_compile_integer_expr(binop.left.as_ref(), name)?;
            let rhs = self.jit_compile_integer_expr(binop.right.as_ref(), name)?;
            match binop.op {
                '+' => self
                    .builder
                    .build_int_add(lhs, rhs, name)
                    .map_err(Into::into),
                '-' => self
                    .builder
                    .build_int_sub(lhs, rhs, name)
                    .map_err(Into::into),
                '*' => self
                    .builder
                    .build_int_mul(lhs, rhs, name)
                    .map_err(Into::into),
                '/' => self
                    .builder
                    .build_int_signed_div(lhs, rhs, name)
                    .map_err(Into::into),
                '%' => self
                    .builder
                    .build_int_signed_rem(lhs, rhs, name)
                    .map_err(Into::into),
                _ => Err(anyhow!("unknown integer binary op '{}'", binop.op)),
            }
        }
        _ => Err(anyhow!("unsupported integer expression '{}'", expr)),
    }
}
/// Resets the per-function code generation state before a new function is built.
///
/// Forgets the current function value and tensor pointer and drops all recorded
/// variables. The `functions` map is left untouched.
fn clear(&mut self) {
    self.fn_value_opt = None;
    self.tensor_ptr_opt = None;
    self.variables.clear();
}
fn build_dep_call(
&mut self,
dep_fn: FunctionValue<'ctx>,
call_name: &str,
barrier_start: u64,
total_barriers: u64,
) -> Result<()> {
let t = self
.build_load(self.real_type, *self.get_param("t"), "t")?
.into_float_value();
let u = *self.get_param("u");
let data = *self.get_param("data");
let thread_id = self
.build_load(self.int_type, *self.get_param("thread_id"), "thread_id")?
.into_int_value();
let thread_dim = self
.build_load(self.int_type, *self.get_param("thread_dim"), "thread_dim")?
.into_int_value();
let barrier_start = self.int_type.const_int(barrier_start, false);
let total_barriers = self.int_type.const_int(total_barriers, false);
self.builder.build_call(
dep_fn,
&[
t.into(),
u.into(),
data.into(),
thread_id.into(),
thread_dim.into(),
barrier_start.into(),
total_barriers.into(),
],
call_name,
)?;
Ok(())
}
/// Returns the module's `calc_time_dep` function. If the module does not have it
/// yet, it is compiled from the model's time-dependent definitions.
fn ensure_time_dep_fn<'m>(
    &mut self,
    model: &'m DiscreteModel,
    code: Option<&str>,
) -> Result<FunctionValue<'ctx>> {
    let existing = self.module.get_function("calc_time_dep");
    match existing {
        Some(function) => Ok(function),
        None => self.compile_dep_defns(model, "calc_time_dep", model.time_dep_defns(), code),
    }
}
/// Returns the module's `calc_state_dep` function. If the module does not have it
/// yet, it is compiled from the model's state-dependent definitions.
fn ensure_state_dep_fn<'m>(
    &mut self,
    model: &'m DiscreteModel,
    code: Option<&str>,
) -> Result<FunctionValue<'ctx>> {
    let existing = self.module.get_function("calc_state_dep");
    match existing {
        Some(function) => Ok(function),
        None => self.compile_dep_defns(model, "calc_state_dep", model.state_dep_defns(), code),
    }
}
/// Returns the module's `calc_state_dep_post_f` function. If the module does not
/// have it yet, it is compiled from `model.state_dep_post_f_defns()`.
fn ensure_state_dep_post_f_fn<'m>(
    &mut self,
    model: &'m DiscreteModel,
    code: Option<&str>,
) -> Result<FunctionValue<'ctx>> {
    let existing = self.module.get_function("calc_state_dep_post_f");
    match existing {
        Some(function) => Ok(function),
        None => self.compile_dep_defns(
            model,
            "calc_state_dep_post_f",
            model.state_dep_post_f_defns(),
            code,
        ),
    }
}
/// Returns a pointer through which the function argument `arg` can be read.
///
/// Pointer arguments are returned as-is. Float and int arguments are spilled into a
/// new alloca named `name`: the alloca is created with the entry-block builder, and
/// the value is stored with the main builder.
///
/// # Panics
/// Panics if building the alloca or the store fails, or if `arg` is any other kind
/// of value.
fn function_arg_alloca(&mut self, name: &str, arg: BasicValueEnum<'ctx>) -> PointerValue<'ctx> {
    match arg {
        BasicValueEnum::PointerValue(ptr) => ptr,
        BasicValueEnum::FloatValue(_) | BasicValueEnum::IntValue(_) => {
            let slot = self
                .create_entry_block_builder()
                .build_alloca(arg.get_type(), name)
                .unwrap();
            self.builder.build_store(slot, arg).unwrap();
            slot
        }
        _ => unreachable!(),
    }
}
/// Generates `set_u0(u0, data, thread_id, thread_dim)`.
///
/// The function evaluates every input-dependent definition and then the model state
/// into `u0`, calling a barrier after each tensor. `u0` and `data` are marked
/// `noalias`. If the generated function fails LLVM verification, it is printed to
/// stderr, deleted, and an error is returned.
pub fn compile_set_u0<'m>(
    &mut self,
    model: &'m DiscreteModel,
    code: Option<&str>,
) -> Result<FunctionValue<'ctx>> {
    self.clear();
    let fn_arg_names = &["u0", "data", "thread_id", "thread_dim"];
    let function = self.add_function(
        "set_u0",
        fn_arg_names,
        &[
            self.real_ptr_type.into(),
            self.real_ptr_type.into(),
            self.int_type.into(),
            self.int_type.into(),
        ],
        None,
        false,
    );
    let noalias = self
        .context
        .create_enum_attribute(Attribute::get_named_enum_kind_id("noalias"), 0);
    for param in [0, 1] {
        function.add_attribute(AttributeLoc::Param(param), noalias);
    }
    let _entry_block = self.start_function(function, code);
    for (arg_name, arg) in fn_arg_names.iter().copied().zip(function.get_param_iter()) {
        let slot = self.function_arg_alloca(arg_name, arg);
        self.insert_param(arg_name, slot);
    }
    self.insert_data(model);
    self.insert_indices();

    // one barrier per input-dependent definition, plus one for the state
    let input_defns = model.input_dep_defns();
    let total_barriers = self
        .int_type
        .const_int((input_defns.len() + 1) as u64, false);
    for (barrier, tensor) in (1u64..).zip(input_defns) {
        self.jit_compile_tensor(tensor, Some(*self.get_var(tensor)), code)?;
        let barrier_num = self.int_type.const_int(barrier, false);
        self.jit_compile_call_barrier(barrier_num, total_barriers);
    }
    self.jit_compile_tensor(model.state(), Some(*self.get_param("u0")), code)?;
    let last_barrier = self
        .int_type
        .const_int(input_defns.len() as u64 + 1, false);
    self.jit_compile_call_barrier(last_barrier, total_barriers);
    self.builder.build_return(None)?;

    if !function.verify(true) {
        function.print_to_stderr();
        // SAFETY: `function` was created by this method and is not used again after
        // deletion; only the error is returned.
        unsafe {
            function.delete();
        }
        return Err(anyhow!("Invalid generated function."));
    }
    Ok(function)
}
/// Generates `calc_out(t, u, data, out, thread_id, thread_dim)`, or `calc_out_full`
/// when `include_constants` is set.
///
/// If the model has an output tensor, the function does the following, with barriers
/// numbered consecutively:
/// 1. optionally evaluates the input-dependent definitions (one barrier each);
/// 2. calls the non-empty time-dep, state-dep and state-dep-post-f helpers;
/// 3. evaluates the output tensor, followed by a final barrier.
///
/// `u`, `data` and `out` are marked `noalias`. If the generated function fails LLVM
/// verification, it is printed, deleted, and an error is returned.
pub fn compile_calc_out<'m>(
    &mut self,
    model: &'m DiscreteModel,
    include_constants: bool,
    code: Option<&str>,
) -> Result<FunctionValue<'ctx>> {
    // the dependency helpers are compiled (once) before this function is started
    let time_dep_fn = self.ensure_time_dep_fn(model, code)?;
    let state_dep_fn = self.ensure_state_dep_fn(model, code)?;
    let state_dep_post_f_fn = self.ensure_state_dep_post_f_fn(model, code)?;
    self.clear();
    let fn_arg_names = &["t", "u", "data", "out", "thread_id", "thread_dim"];
    let function_name = if include_constants {
        "calc_out_full"
    } else {
        "calc_out"
    };
    let function = self.add_function(
        function_name,
        fn_arg_names,
        &[
            self.real_type.into(),
            self.real_ptr_type.into(),
            self.real_ptr_type.into(),
            self.real_ptr_type.into(),
            self.int_type.into(),
            self.int_type.into(),
        ],
        None,
        false,
    );
    let noalias = self
        .context
        .create_enum_attribute(Attribute::get_named_enum_kind_id("noalias"), 0);
    for param in [1, 2, 3] {
        function.add_attribute(AttributeLoc::Param(param), noalias);
    }
    let _entry_block = self.start_function(function, code);
    for (arg_name, arg) in fn_arg_names.iter().copied().zip(function.get_param_iter()) {
        let slot = self.function_arg_alloca(arg_name, arg);
        self.insert_param(arg_name, slot);
    }
    self.insert_state(model.state());
    self.insert_data(model);
    self.insert_indices();

    if let Some(out) = model.out() {
        let input_barriers = if include_constants {
            model.input_dep_defns().len()
        } else {
            0
        };
        let total_barriers = (input_barriers
            + model.time_dep_defns().len()
            + model.state_dep_defns().len()
            + model.state_dep_post_f_defns().len()
            + 1) as u64;
        let total_barriers_val = self.int_type.const_int(total_barriers, false);
        let mut nbarriers = 0u64;
        if include_constants {
            // time-independent definitions
            for tensor in model.input_dep_defns() {
                self.jit_compile_tensor(tensor, Some(*self.get_var(tensor)), code)?;
                nbarriers += 1;
                let barrier_num = self.int_type.const_int(nbarriers, false);
                self.jit_compile_call_barrier(barrier_num, total_barriers_val);
            }
        }
        // time- and state-dependent definitions, via their helper functions
        for (dep_fn, call_name, defn_count) in [
            (time_dep_fn, "time_dep", model.time_dep_defns().len()),
            (state_dep_fn, "state_dep", model.state_dep_defns().len()),
            (
                state_dep_post_f_fn,
                "state_dep",
                model.state_dep_post_f_defns().len(),
            ),
        ] {
            if defn_count > 0 {
                self.build_dep_call(dep_fn, call_name, nbarriers, total_barriers)?;
                nbarriers += defn_count as u64;
            }
        }
        self.jit_compile_tensor(out, Some(*self.get_var(out)), code)?;
        let barrier_num = self.int_type.const_int(nbarriers + 1, false);
        self.jit_compile_call_barrier(barrier_num, total_barriers_val);
    }
    self.builder.build_return(None)?;

    if !function.verify(true) {
        function.print_to_stderr();
        // SAFETY: `function` was created by this method and is not used again after
        // deletion; only the error is returned.
        unsafe {
            function.delete();
        }
        return Err(anyhow!("Invalid generated function."));
    }
    Ok(function)
}
/// Generates a dependency function `fn_name(t, u, data, thread_id, thread_dim,
/// barrier_start, total_barriers)` that evaluates `tensors` in order.
///
/// After the tensor at position `k` it calls barrier number `barrier_start + k + 1`
/// out of `total_barriers`. `u` and `data` are marked `noalias`. If the generated
/// function fails LLVM verification, it is printed, deleted, and an error is
/// returned.
fn compile_dep_defns<'m>(
    &mut self,
    model: &'m DiscreteModel,
    fn_name: &str,
    tensors: &[Tensor<'m>],
    code: Option<&str>,
) -> Result<FunctionValue<'ctx>> {
    self.clear();
    let fn_arg_names = &[
        "t",
        "u",
        "data",
        "thread_id",
        "thread_dim",
        "barrier_start",
        "total_barriers",
    ];
    let function = self.add_function(
        fn_name,
        fn_arg_names,
        &[
            self.real_type.into(),
            self.real_ptr_type.into(),
            self.real_ptr_type.into(),
            self.int_type.into(),
            self.int_type.into(),
            self.int_type.into(),
            self.int_type.into(),
        ],
        None,
        false,
    );
    let noalias = self
        .context
        .create_enum_attribute(Attribute::get_named_enum_kind_id("noalias"), 0);
    for param in [1, 2] {
        function.add_attribute(AttributeLoc::Param(param), noalias);
    }
    let _entry_block = self.start_function(function, code);
    for (arg_name, arg) in fn_arg_names.iter().copied().zip(function.get_param_iter()) {
        let slot = self.function_arg_alloca(arg_name, arg);
        self.insert_param(arg_name, slot);
    }
    self.insert_state(model.state());
    self.insert_data(model);
    self.insert_indices();

    let barrier_start_ptr = *self.get_param("barrier_start");
    let barrier_start = self
        .build_load(self.int_type, barrier_start_ptr, "barrier_start")?
        .into_int_value();
    let total_barriers_ptr = *self.get_param("total_barriers");
    let total_barriers = self
        .build_load(self.int_type, total_barriers_ptr, "total_barriers")?
        .into_int_value();
    let one = self.int_type.const_int(1, false);
    for (position, tensor) in (0u64..).zip(tensors) {
        self.jit_compile_tensor(tensor, Some(*self.get_var(tensor)), code)?;
        let offset = self.int_type.const_int(position, false);
        let barrier_num_base =
            self.builder
                .build_int_add(barrier_start, offset, "barrier_num_base")?;
        let barrier_num = self
            .builder
            .build_int_add(barrier_num_base, one, "barrier_num")?;
        self.jit_compile_call_barrier(barrier_num, total_barriers);
    }
    self.builder.build_return(None)?;

    if !function.verify(true) {
        function.print_to_stderr();
        // SAFETY: `function` was created by this method and is not used again after
        // deletion; only the error is returned.
        unsafe {
            function.delete();
        }
        return Err(anyhow!("Invalid generated function."));
    }
    Ok(function)
}
/// Generates `calc_stop(t, u, data, root, thread_id, thread_dim)`, or
/// `calc_stop_full` when `include_constants` is set.
///
/// If the model has a stop tensor, the function does the following, with barriers
/// numbered consecutively:
/// 1. optionally evaluates the input-dependent definitions;
/// 2. calls the non-empty time-dep, state-dep and state-dep-post-f helpers;
/// 3. writes the stop tensor into `root`, followed by a final barrier.
///
/// `u`, `data` and `root` are marked `noalias`. If the generated function fails LLVM
/// verification, it is printed, deleted, and an error is returned.
pub fn compile_calc_stop<'m>(
    &mut self,
    model: &'m DiscreteModel,
    include_constants: bool,
    code: Option<&str>,
) -> Result<FunctionValue<'ctx>> {
    // the dependency helpers are compiled (once) before this function is started
    let time_dep_fn = self.ensure_time_dep_fn(model, code)?;
    let state_dep_fn = self.ensure_state_dep_fn(model, code)?;
    let state_dep_post_f_fn = self.ensure_state_dep_post_f_fn(model, code)?;
    self.clear();
    let fn_arg_names = &["t", "u", "data", "root", "thread_id", "thread_dim"];
    let function_name = if include_constants {
        "calc_stop_full"
    } else {
        "calc_stop"
    };
    let function = self.add_function(
        function_name,
        fn_arg_names,
        &[
            self.real_type.into(),
            self.real_ptr_type.into(),
            self.real_ptr_type.into(),
            self.real_ptr_type.into(),
            self.int_type.into(),
            self.int_type.into(),
        ],
        None,
        false,
    );
    let noalias = self
        .context
        .create_enum_attribute(Attribute::get_named_enum_kind_id("noalias"), 0);
    for param in [1, 2, 3] {
        function.add_attribute(AttributeLoc::Param(param), noalias);
    }
    let _entry_block = self.start_function(function, code);
    for (arg_name, arg) in fn_arg_names.iter().copied().zip(function.get_param_iter()) {
        let slot = self.function_arg_alloca(arg_name, arg);
        self.insert_param(arg_name, slot);
    }
    self.insert_state(model.state());
    self.insert_data(model);
    self.insert_indices();

    if let Some(stop) = model.stop() {
        let input_barriers = if include_constants {
            model.input_dep_defns().len()
        } else {
            0
        };
        let total_barriers = (input_barriers
            + model.time_dep_defns().len()
            + model.state_dep_defns().len()
            + model.state_dep_post_f_defns().len()
            + 1) as u64;
        let total_barriers_val = self.int_type.const_int(total_barriers, false);
        let mut nbarriers = 0u64;
        if include_constants {
            // time-independent definitions
            for tensor in model.input_dep_defns() {
                self.jit_compile_tensor(tensor, Some(*self.get_var(tensor)), code)?;
                nbarriers += 1;
                let barrier_num = self.int_type.const_int(nbarriers, false);
                self.jit_compile_call_barrier(barrier_num, total_barriers_val);
            }
        }
        // time- and state-dependent definitions, via their helper functions
        for (dep_fn, call_name, defn_count) in [
            (time_dep_fn, "time_dep", model.time_dep_defns().len()),
            (state_dep_fn, "state_dep", model.state_dep_defns().len()),
            (
                state_dep_post_f_fn,
                "state_dep",
                model.state_dep_post_f_defns().len(),
            ),
        ] {
            if defn_count > 0 {
                self.build_dep_call(dep_fn, call_name, nbarriers, total_barriers)?;
                nbarriers += defn_count as u64;
            }
        }
        let root = *self.get_param("root");
        self.jit_compile_tensor(stop, Some(root), code)?;
        let barrier_num = self.int_type.const_int(nbarriers + 1, false);
        self.jit_compile_call_barrier(barrier_num, total_barriers_val);
    }
    self.builder.build_return(None)?;

    if !function.verify(true) {
        function.print_to_stderr();
        // SAFETY: `function` was created by this method and is not used again after
        // deletion; only the error is returned.
        unsafe {
            function.delete();
        }
        return Err(anyhow!("Invalid generated function."));
    }
    Ok(function)
}
/// Emits `reset` (or `reset_full` when `include_constants` is set), which evaluates the model's
/// reset expression into the `reset` buffer.
///
/// Generated signature: `(t, u, data, reset, thread_id, thread_dim)`, with `u`, `data` and
/// `reset` (parameters 1..=3) marked `noalias`. When the model defines a reset expression, the
/// body first recomputes the definitions it may depend on, then writes the reset tensor:
/// - input-dependent definitions (only with `include_constants`);
/// - time-dependent definitions;
/// - state-dependent definitions;
/// - post-F state-dependent definitions.
///
/// Without a reset expression the generated function returns immediately.
///
/// # Errors
/// Propagates failures from the `ensure_*` helpers, tensor compilation and the IR builder. If LLVM
/// verification fails, the function is printed to stderr, deleted, and an error is returned.
pub fn compile_reset<'m>(
    &mut self,
    model: &'m DiscreteModel,
    include_constants: bool,
    code: Option<&str>,
) -> Result<FunctionValue<'ctx>> {
    let time_dep_fn = self.ensure_time_dep_fn(model, code)?;
    let state_dep_fn = self.ensure_state_dep_fn(model, code)?;
    let post_f_fn = self.ensure_state_dep_post_f_fn(model, code)?;
    self.clear();

    let fn_arg_names = &["t", "u", "data", "reset", "thread_id", "thread_dim"];
    let function = self.add_function(
        if include_constants { "reset_full" } else { "reset" },
        fn_arg_names,
        &[
            self.real_type.into(),
            self.real_ptr_type.into(),
            self.real_ptr_type.into(),
            self.real_ptr_type.into(),
            self.int_type.into(),
            self.int_type.into(),
        ],
        None,
        false,
    );

    // u, data and reset are marked noalias.
    let noalias = self
        .context
        .create_enum_attribute(Attribute::get_named_enum_kind_id("noalias"), 0);
    for param_index in 1..=3u32 {
        function.add_attribute(AttributeLoc::Param(param_index), noalias);
    }

    let _basic_block = self.start_function(function, code);
    for (arg_name, arg_value) in fn_arg_names.iter().copied().zip(function.get_param_iter()) {
        let arg_slot = self.function_arg_alloca(arg_name, arg_value);
        self.insert_param(arg_name, arg_slot);
    }
    self.insert_state(model.state());
    self.insert_data(model);
    self.insert_indices();

    if let Some(reset) = model.reset() {
        let constants_len = if include_constants {
            model.input_dep_defns().len()
        } else {
            0
        };
        let total_barriers = (constants_len
            + model.time_dep_defns().len()
            + model.state_dep_defns().len()
            + model.state_dep_post_f_defns().len()
            + 1) as u64;
        let total_barriers_val = self.int_type.const_int(total_barriers, false);
        let mut done: u64 = 0;

        if include_constants {
            // calculate time independent definitions
            for tensor in model.input_dep_defns() {
                let target = *self.get_var(tensor);
                self.jit_compile_tensor(tensor, Some(target), code)?;
                done += 1;
                let barrier_num = self.int_type.const_int(done, false);
                self.jit_compile_call_barrier(barrier_num, total_barriers_val);
            }
        }

        // NOTE(review): the post-F stage reuses the "state_dep" label; confirm that is intended.
        let stages = [
            (time_dep_fn, "time_dep", model.time_dep_defns().len()),
            (state_dep_fn, "state_dep", model.state_dep_defns().len()),
            (post_f_fn, "state_dep", model.state_dep_post_f_defns().len()),
        ];
        for (stage_fn, label, stage_len) in stages {
            if stage_len == 0 {
                continue;
            }
            self.build_dep_call(stage_fn, label, done, total_barriers)?;
            done += stage_len as u64;
        }

        let reset_ptr = *self.get_param("reset");
        self.jit_compile_tensor(reset, Some(reset_ptr), code)?;
        let barrier_num = self.int_type.const_int(done + 1, false);
        self.jit_compile_call_barrier(barrier_num, total_barriers_val);
    }

    self.builder.build_return(None)?;
    if !function.verify(true) {
        function.print_to_stderr();
        unsafe {
            function.delete();
        }
        return Err(anyhow!("Invalid generated function."));
    }
    Ok(function)
}
/// Emits `rhs` (or `rhs_full` when `include_constants` is set), which evaluates the model's
/// right-hand side into the `rr` buffer.
///
/// Generated signature: `(t, u, data, rr, thread_id, thread_dim)`, with `u`, `data` and `rr`
/// (parameters 1..=3) marked `noalias`. The body runs these steps in order:
/// - recomputes the input-dependent definitions (only with `include_constants`), each followed
///   by a barrier;
/// - calls the time-dependent and state-dependent helper functions for every non-empty stage;
/// - writes `model.rhs()` into `rr` and emits a final barrier.
///
/// # Errors
/// Propagates failures from the `ensure_*` helpers, tensor compilation and the IR builder. If LLVM
/// verification fails, the function is printed to stderr, deleted, and an error is returned.
pub fn compile_rhs<'m>(
    &mut self,
    model: &'m DiscreteModel,
    include_constants: bool,
    code: Option<&str>,
) -> Result<FunctionValue<'ctx>> {
    let time_dep_fn = self.ensure_time_dep_fn(model, code)?;
    let state_dep_fn = self.ensure_state_dep_fn(model, code)?;
    self.clear();

    let fn_arg_names = &["t", "u", "data", "rr", "thread_id", "thread_dim"];
    let function_name = if include_constants { "rhs_full" } else { "rhs" };
    let function = self.add_function(
        function_name,
        fn_arg_names,
        &[
            self.real_type.into(),
            self.real_ptr_type.into(),
            self.real_ptr_type.into(),
            self.real_ptr_type.into(),
            self.int_type.into(),
            self.int_type.into(),
        ],
        None,
        false,
    );

    // u, data and rr are marked noalias.
    let noalias_id = Attribute::get_named_enum_kind_id("noalias");
    let noalias = self.context.create_enum_attribute(noalias_id, 0);
    for param_index in 1..=3u32 {
        function.add_attribute(AttributeLoc::Param(param_index), noalias);
    }

    let _basic_block = self.start_function(function, code);
    for (arg_name, arg_value) in fn_arg_names.iter().copied().zip(function.get_param_iter()) {
        let slot = self.function_arg_alloca(arg_name, arg_value);
        self.insert_param(arg_name, slot);
    }
    self.insert_state(model.state());
    self.insert_data(model);
    self.insert_indices();

    let constant_count = if include_constants {
        model.input_dep_defns().len()
    } else {
        0
    };
    let total_barriers = (constant_count
        + model.time_dep_defns().len()
        + model.state_dep_defns().len()
        + 1) as u64;
    let total_barriers_val = self.int_type.const_int(total_barriers, false);
    let mut nbarriers: u64 = 0;

    if include_constants {
        // calculate constant definitions
        for tensor in model.input_dep_defns() {
            let dest = *self.get_var(tensor);
            self.jit_compile_tensor(tensor, Some(dest), code)?;
            nbarriers += 1;
            let barrier_num = self.int_type.const_int(nbarriers, false);
            self.jit_compile_call_barrier(barrier_num, total_barriers_val);
        }
    }

    // calculate time dependant, then state dependant, definitions
    let stages = [
        (time_dep_fn, "time_dep", model.time_dep_defns().len()),
        (state_dep_fn, "state_dep", model.state_dep_defns().len()),
    ];
    for (stage_fn, label, defn_count) in stages {
        if defn_count > 0 {
            self.build_dep_call(stage_fn, label, nbarriers, total_barriers)?;
            nbarriers += defn_count as u64;
        }
    }

    // F
    let rr = *self.get_param("rr");
    self.jit_compile_tensor(model.rhs(), Some(rr), code)?;
    let barrier_num = self.int_type.const_int(nbarriers + 1, false);
    self.jit_compile_call_barrier(barrier_num, total_barriers_val);

    self.builder.build_return(None)?;
    if !function.verify(true) {
        function.print_to_stderr();
        unsafe {
            function.delete();
        }
        return Err(anyhow!("Invalid generated function."));
    }
    Ok(function)
}
/// Emits `mass`, which evaluates the model's left-hand-side (`lhs`) tensor into `rr`.
///
/// Generated signature: `(t, dudt, data, rr, thread_id, thread_dim)`, with `dudt`, `data` and
/// `rr` (parameters 1..=3) marked `noalias`. Code is only emitted when the model has both a
/// `state_dot` and an `lhs` tensor. In that case the body runs these steps in order:
/// - recomputes the time-dependent definitions;
/// - recomputes the `dstate`-dependent definitions;
/// - writes `lhs` into `rr`.
///
/// Each step is followed by a barrier call. Otherwise the generated function returns immediately.
///
/// # Errors
/// Propagates failures from tensor compilation and the IR builder. If LLVM verification fails,
/// the function is printed to stderr, deleted, and an error is returned.
pub fn compile_mass<'m>(
    &mut self,
    model: &'m DiscreteModel,
    code: Option<&str>,
) -> Result<FunctionValue<'ctx>> {
    self.clear();
    let fn_arg_names = &["t", "dudt", "data", "rr", "thread_id", "thread_dim"];
    let function = self.add_function(
        "mass",
        fn_arg_names,
        &[
            self.real_type.into(),
            self.real_ptr_type.into(),
            self.real_ptr_type.into(),
            self.real_ptr_type.into(),
            self.int_type.into(),
            self.int_type.into(),
        ],
        None,
        false,
    );
    // add noalias to dudt, data and rr
    let alias_id = Attribute::get_named_enum_kind_id("noalias");
    let noalias = self.context.create_enum_attribute(alias_id, 0);
    for i in 1..=3u32 {
        function.add_attribute(AttributeLoc::Param(i), noalias);
    }
    let _basic_block = self.start_function(function, code);
    for (name, arg) in fn_arg_names.iter().copied().zip(function.get_param_iter()) {
        let alloca = self.function_arg_alloca(name, arg);
        self.insert_param(name, alloca);
    }
    // only put code in this function if we have a state_dot and lhs
    if let (Some(state_dot), Some(lhs)) = (model.state_dot(), model.lhs()) {
        self.insert_dot_state(state_dot);
        self.insert_data(model);
        self.insert_indices();
        let total_barriers =
            (model.time_dep_defns().len() + model.dstate_dep_defns().len() + 1) as u64;
        let total_barriers_val = self.int_type.const_int(total_barriers, false);
        let mut nbarriers: u64 = 0;
        // calculate time dependant definitions, then those depending on dudt
        for tensor in model.time_dep_defns().iter().chain(model.dstate_dep_defns()) {
            let dest = *self.get_var(tensor);
            self.jit_compile_tensor(tensor, Some(dest), code)?;
            nbarriers += 1;
            let barrier_num = self.int_type.const_int(nbarriers, false);
            self.jit_compile_call_barrier(barrier_num, total_barriers_val);
        }
        // mass
        let res_ptr = *self.get_param("rr");
        self.jit_compile_tensor(lhs, Some(res_ptr), code)?;
        let barrier_num = self.int_type.const_int(nbarriers + 1, false);
        self.jit_compile_call_barrier(barrier_num, total_barriers_val);
    }
    self.builder.build_return(None)?;
    if function.verify(true) {
        Ok(function)
    } else {
        function.print_to_stderr();
        unsafe {
            function.delete();
        }
        Err(anyhow!("Invalid generated function."))
    }
}
/// Builds `fn_name`, a wrapper that calls an Enzyme-generated derivative of
/// `original_function`.
///
/// The wrapper takes one parameter for every parameter of `original_function`. For parameters
/// marked `Dup` or `DupNoNeed` in `args_type`, a second parameter of the same type (the shadow)
/// follows immediately. Every pointer parameter of the wrapper (primal or shadow) is marked
/// `noalias`.
///
/// `mode` selects which Enzyme entry point is used:
/// - forward mode uses `EnzymeCreateForwardDiff`;
/// - combined reverse mode uses `EnzymeCreatePrimalAndGradient`.
///
/// When `self.threaded` is set in reverse mode, the `barrier` call handler is registered (under a
/// global lock) for the duration of the Enzyme call.
///
/// # Errors
/// Returns an error if Enzyme does not return a derivative function, or if the generated wrapper
/// fails LLVM verification. In both cases the wrapper is deleted from the module.
///
/// # Panics
/// Panics if `original_function` has a parameter that is not a pointer, float or int type.
/// Also panics if `args_type` has fewer entries than `original_function` has parameters.
pub fn compile_gradient(
    &mut self,
    original_function: FunctionValue<'ctx>,
    args_type: &[CompileGradientArgType],
    mode: CompileMode,
    fn_name: &str,
) -> Result<FunctionValue<'ctx>> {
    self.clear();
    // construct the gradient function's signature: every original arg, followed by its shadow
    // for duplicated (differentiated) args
    let mut fn_type: Vec<BasicMetadataTypeEnum> = Vec::new();
    let mut start_param_index: Vec<u32> = Vec::new();
    let mut ptr_arg_indices: Vec<u32> = Vec::new();
    for (i, arg) in original_function.get_param_iter().enumerate() {
        let start_index = u32::try_from(fn_type.len()).unwrap();
        start_param_index.push(start_index);
        let arg_type = arg.get_type();
        fn_type.push(arg_type.into());
        if arg_type.is_pointer_type() {
            ptr_arg_indices.push(start_index);
        }
        match args_type[i] {
            CompileGradientArgType::Dup | CompileGradientArgType::DupNoNeed => {
                fn_type.push(arg_type.into());
                if arg_type.is_pointer_type() {
                    ptr_arg_indices.push(start_index + 1);
                }
            }
            CompileGradientArgType::Const => {}
        }
    }
    let fn_arg_names = (0..fn_type.len())
        .map(|i| format!("arg{i}"))
        .collect::<Vec<String>>();
    let fn_arg_names_ref = fn_arg_names
        .iter()
        .map(String::as_str)
        .collect::<Vec<&str>>();
    let function = self.add_function(fn_name, &fn_arg_names_ref, &fn_type, None, false);
    // add noalias to every pointer argument (primal and shadow)
    let alias_id = Attribute::get_named_enum_kind_id("noalias");
    let noalias = self.context.create_enum_attribute(alias_id, 0);
    for i in ptr_arg_indices {
        function.add_attribute(AttributeLoc::Param(i), noalias);
    }
    let _basic_block = self.start_function(function, None);
    let mut enzyme_fn_args: Vec<BasicMetadataValueEnum> = Vec::new();
    let mut input_activity = Vec::new();
    let mut arg_trees = Vec::new();
    for (i, arg) in original_function.get_param_iter().enumerate() {
        let param_index = start_param_index[i];
        let fn_arg = function.get_nth_param(param_index).unwrap();
        // we'll probably only get double or pointers to doubles, so let assume this for now
        // todo: perhaps refactor this into a recursive function, might be overkill
        let concrete_type = match arg.get_type() {
            BasicTypeEnum::PointerType(_t) => CConcreteType_DT_Pointer,
            BasicTypeEnum::FloatType(_t) => match self.diffsl_real_type {
                RealType::F32 => CConcreteType_DT_Float,
                RealType::F64 => CConcreteType_DT_Double,
            },
            BasicTypeEnum::IntType(_) => CConcreteType_DT_Integer,
            _ => panic!("unsupported type"),
        };
        let new_tree = unsafe {
            EnzymeNewTypeTreeCT(
                concrete_type,
                self.context.as_ctx_ref() as *mut LLVMOpaqueContext,
            )
        };
        unsafe { EnzymeTypeTreeOnlyEq(new_tree, -1) };
        // pointer to real type
        if concrete_type == CConcreteType_DT_Pointer {
            let inner_concrete_type = match self.diffsl_real_type {
                RealType::F32 => CConcreteType_DT_Float,
                RealType::F64 => CConcreteType_DT_Double,
            };
            let inner_new_tree = unsafe {
                EnzymeNewTypeTreeCT(
                    inner_concrete_type,
                    self.context.as_ctx_ref() as *mut LLVMOpaqueContext,
                )
            };
            // Deliberately applied twice. The pointee presumably sits one level below the pointer,
            // giving a tree like `{[-1]: Pointer, [-1,-1]: Real}` — confirm against Enzyme's
            // TypeTree docs before changing.
            unsafe { EnzymeTypeTreeOnlyEq(inner_new_tree, -1) };
            unsafe { EnzymeTypeTreeOnlyEq(inner_new_tree, -1) };
            unsafe { EnzymeMergeTypeTree(new_tree, inner_new_tree) };
            // The merge copies `inner_new_tree` into `new_tree`; the source tree is still
            // owned by us and must be freed here or it leaks.
            unsafe { EnzymeFreeTypeTree(inner_new_tree) };
        }
        arg_trees.push(new_tree);
        match args_type[i] {
            CompileGradientArgType::Dup => {
                // pass in the arg value
                enzyme_fn_args.push(fn_arg.into());
                // pass in the darg value
                let fn_darg = function.get_nth_param(param_index + 1).unwrap();
                enzyme_fn_args.push(fn_darg.into());
                input_activity.push(CDIFFE_TYPE_DFT_DUP_ARG);
            }
            CompileGradientArgType::DupNoNeed => {
                // pass in the arg value
                enzyme_fn_args.push(fn_arg.into());
                // pass in the darg value
                let fn_darg = function.get_nth_param(param_index + 1).unwrap();
                enzyme_fn_args.push(fn_darg.into());
                input_activity.push(CDIFFE_TYPE_DFT_DUP_NONEED);
            }
            CompileGradientArgType::Const => {
                // pass in the arg value
                enzyme_fn_args.push(fn_arg.into());
                input_activity.push(CDIFFE_TYPE_DFT_CONSTANT);
            }
        }
    }
    // if we have void ret, this must be false;
    let ret_primary_ret = false;
    let diff_ret = false;
    let ret_activity = CDIFFE_TYPE_DFT_CONSTANT;
    let ret_tree = unsafe {
        EnzymeNewTypeTreeCT(
            CConcreteType_DT_Anything,
            self.context.as_ctx_ref() as *mut LLVMOpaqueContext,
        )
    };
    // always optimize
    let fnc_opt_base = true;
    let logic_ref: EnzymeLogicRef = unsafe { CreateEnzymeLogic(fnc_opt_base as u8) };
    let kv_tmp = IntList {
        data: std::ptr::null_mut(),
        size: 0,
    };
    let mut known_values = vec![kv_tmp; input_activity.len()];
    let fn_type_info = CFnTypeInfo {
        Arguments: arg_trees.as_mut_ptr(),
        Return: ret_tree,
        KnownValues: known_values.as_mut_ptr(),
    };
    let type_analysis: EnzymeTypeAnalysisRef =
        unsafe { CreateTypeAnalysis(logic_ref, std::ptr::null_mut(), std::ptr::null_mut(), 0) };
    let mut args_uncacheable = vec![0; arg_trees.len()];
    let enzyme_function = match mode {
        CompileMode::Forward | CompileMode::ForwardSens => unsafe {
            EnzymeCreateForwardDiff(
                logic_ref, // Logic
                std::ptr::null_mut(),
                std::ptr::null_mut(),
                original_function.as_value_ref(),
                ret_activity, // LLVM function, return type
                input_activity.as_mut_ptr(),
                input_activity.len(), // constant arguments
                type_analysis,        // type analysis struct
                ret_primary_ret as u8,
                CDerivativeMode_DEM_ForwardMode, // return value, dret_used, top_level which was 1
                1,                               // free memory
                0,                               // runtime activity
                0,                               // strong zero
                1,                               // vector mode width
                std::ptr::null_mut(),            // additional argument
                fn_type_info,                    // additional_arg, type info (return + args)
                1,                               // subsequent calls may write
                args_uncacheable.as_mut_ptr(),   // overwritten args
                args_uncacheable.len(),          // overwritten args length
                std::ptr::null_mut(),            // write augmented function to this
            )
        },
        CompileMode::Reverse | CompileMode::ReverseSens => {
            let mut call_enzyme = || unsafe {
                EnzymeCreatePrimalAndGradient(
                    logic_ref,
                    std::ptr::null_mut(),
                    std::ptr::null_mut(),
                    original_function.as_value_ref(),
                    ret_activity,
                    input_activity.as_mut_ptr(),
                    input_activity.len(),
                    type_analysis,
                    ret_primary_ret as u8,
                    diff_ret as u8,
                    CDerivativeMode_DEM_ReverseModeCombined,
                    0,
                    0, // strong zero
                    1,
                    1,
                    std::ptr::null_mut(),
                    0, // force annonymous tape
                    fn_type_info,
                    0, // subsequent calls may write
                    args_uncacheable.as_mut_ptr(),
                    args_uncacheable.len(),
                    std::ptr::null_mut(),
                    if self.threaded { 1 } else { 0 }, // atomic add
                )
            };
            if self.threaded {
                // the register call handler alters a global variable, so we need to lock it
                let _lock = my_mutex.lock().unwrap();
                let barrier_string = CString::new("barrier").unwrap();
                unsafe {
                    EnzymeRegisterCallHandler(
                        barrier_string.as_ptr(),
                        Some(fwd_handler),
                        Some(rev_handler),
                    )
                };
                let ret = call_enzyme();
                // unregister it so some other thread doesn't use it
                unsafe { EnzymeRegisterCallHandler(barrier_string.as_ptr(), None, None) };
                ret
            } else {
                call_enzyme()
            }
        }
    };
    // free everything
    unsafe { FreeEnzymeLogic(logic_ref) };
    unsafe { FreeTypeAnalysis(type_analysis) };
    unsafe { EnzymeFreeTypeTree(ret_tree) };
    for tree in arg_trees {
        unsafe { EnzymeFreeTypeTree(tree) };
    }
    // SAFETY: `enzyme_function` is the raw value returned by Enzyme for this module;
    // `FunctionValue::new` yields `None` for a null pointer.
    let enzyme_function = match unsafe { FunctionValue::new(enzyme_function as LLVMValueRef) } {
        Some(f) => f,
        None => {
            // SAFETY: `function` was created above in this call and is not referenced elsewhere.
            unsafe { function.delete() };
            return Err(anyhow!(
                "Enzyme failed to generate a derivative function for {fn_name}."
            ));
        }
    };
    // call enzyme function
    self.builder
        .build_call(enzyme_function, enzyme_fn_args.as_slice(), "enzyme_call")?;
    // return
    self.builder.build_return(None)?;
    if function.verify(true) {
        Ok(function)
    } else {
        function.print_to_stderr();
        enzyme_function.print_to_stderr();
        unsafe {
            function.delete();
        }
        Err(anyhow!("Invalid generated function."))
    }
}
/// Emits `get_dims`, which reports the model's sizes through seven integer out-pointers.
///
/// The parameters are written in this order:
/// - `states`: nnz of the state tensor;
/// - `inputs`: nnz of the input tensor, or 0 when there is none;
/// - `outputs`: nnz of the output tensor, or 0 when there is none;
/// - `data`: length of the data layout;
/// - `stop`: nnz of the stop tensor, or 0 when there is none;
/// - `has_mass`: 1 if the model has an `lhs`, else 0;
/// - `has_reset`: 1 if the model has a reset, else 0.
///
/// # Errors
/// Propagates IR builder failures. If LLVM verification fails, the function is printed to
/// stderr, deleted, and an error is returned.
pub fn compile_get_dims(&mut self, model: &DiscreteModel) -> Result<FunctionValue<'ctx>> {
    self.clear();
    let fn_arg_names = &[
        "states",
        "inputs",
        "outputs",
        "data",
        "stop",
        "has_mass",
        "has_reset",
    ];
    let function = self.add_function(
        "get_dims",
        fn_arg_names,
        &[
            self.int_ptr_type.into(),
            self.int_ptr_type.into(),
            self.int_ptr_type.into(),
            self.int_ptr_type.into(),
            self.int_ptr_type.into(),
            self.int_ptr_type.into(),
            self.int_ptr_type.into(),
        ],
        None,
        false,
    );
    let _block = self.start_function(function, None);
    for (param_name, param) in fn_arg_names.iter().copied().zip(function.get_param_iter()) {
        let slot = self.function_arg_alloca(param_name, param);
        self.insert_param(param_name, slot);
    }
    self.insert_indices();

    // (out-parameter, value) pairs, stored in this exact order.
    let dims: [(&str, u64); 7] = [
        ("states", model.state().nnz() as u64),
        ("inputs", model.input().map_or(0, |inp| inp.nnz()) as u64),
        ("outputs", model.out().map_or(0, |out| out.nnz() as u64)),
        ("data", self.layout.data().len() as u64),
        ("stop", model.stop().map_or(0, |stop| stop.nnz() as u64)),
        ("has_mass", u64::from(model.lhs().is_some())),
        ("has_reset", u64::from(model.reset().is_some())),
    ];
    for &(param_name, value) in &dims {
        let dest = *self.get_param(param_name);
        self.builder
            .build_store(dest, self.int_type.const_int(value, false))?;
    }

    self.builder.build_return(None)?;
    if !function.verify(true) {
        function.print_to_stderr();
        unsafe {
            function.delete();
        }
        return Err(anyhow!("Invalid generated function."));
    }
    Ok(function)
}
/// Emits `get_tensor_{name}`, which reports where tensor `name` lives inside the `data` buffer.
///
/// Generated signature: `(data, tensor_data, tensor_size)`. The function stores the tensor's
/// address into `*tensor_data`. It stores the `nnz()` of the tensor's layout into
/// `*tensor_size`.
///
/// # Errors
/// Returns an error if `name` has no entry in the data layout. The check runs before anything is
/// emitted, so an unknown name no longer panics or leaves a half-built function behind. Also
/// propagates IR builder failures. If LLVM verification fails, the function is printed to stderr,
/// deleted, and an error is returned.
pub fn compile_get_tensor(
    &mut self,
    model: &DiscreteModel,
    name: &str,
) -> Result<FunctionValue<'ctx>> {
    self.clear();
    let tensor_size = self
        .layout
        .get_layout(name)
        .ok_or_else(|| anyhow!("Unknown tensor '{name}': not present in the data layout."))?
        .nnz() as u64;
    let real_ptr_ptr_type = Self::pointer_type(self.context, self.real_ptr_type.into());
    let function_name = format!("get_tensor_{name}");
    let fn_arg_names = &["data", "tensor_data", "tensor_size"];
    let function = self.add_function(
        function_name.as_str(),
        fn_arg_names,
        &[
            self.real_ptr_type.into(),
            real_ptr_ptr_type.into(),
            self.int_ptr_type.into(),
        ],
        None,
        false,
    );
    let _basic_block = self.start_function(function, None);
    for (arg_name, arg) in fn_arg_names.iter().copied().zip(function.get_param_iter()) {
        let alloca = self.function_arg_alloca(arg_name, arg);
        self.insert_param(arg_name, alloca);
    }
    self.insert_data(model);
    let ptr = self.get_param(name);
    let tensor_size_value = self.int_type.const_int(tensor_size, false);
    self.builder
        .build_store(*self.get_param("tensor_data"), ptr.as_basic_value_enum())?;
    self.builder
        .build_store(*self.get_param("tensor_size"), tensor_size_value)?;
    self.builder.build_return(None)?;
    if function.verify(true) {
        Ok(function)
    } else {
        function.print_to_stderr();
        unsafe {
            function.delete();
        }
        Err(anyhow!("Invalid generated function."))
    }
}
/// Emits `get_constant_{name}`, which reports the location and size of constant tensor `name`.
///
/// Generated signature: `(tensor_data, tensor_size)`. After `insert_constants(model)` has
/// registered the model's constants, the function stores the address of `name` into
/// `*tensor_data`. It stores the `nnz()` of its layout into `*tensor_size`.
///
/// # Errors
/// Returns an error if `name` has no entry in the layout. The check runs before anything is
/// emitted, so an unknown name no longer panics or leaves a half-built function behind. Also
/// propagates IR builder failures. If LLVM verification fails, the function is printed to stderr,
/// deleted, and an error is returned.
pub fn compile_get_constant(
    &mut self,
    model: &DiscreteModel,
    name: &str,
) -> Result<FunctionValue<'ctx>> {
    self.clear();
    let tensor_size = self
        .layout
        .get_layout(name)
        .ok_or_else(|| anyhow!("Unknown constant '{name}': not present in the data layout."))?
        .nnz() as u64;
    let real_ptr_ptr_type = Self::pointer_type(self.context, self.real_ptr_type.into());
    let function_name = format!("get_constant_{name}");
    let fn_arg_names = &["tensor_data", "tensor_size"];
    let function = self.add_function(
        function_name.as_str(),
        fn_arg_names,
        &[real_ptr_ptr_type.into(), self.int_ptr_type.into()],
        None,
        false,
    );
    let _basic_block = self.start_function(function, None);
    for (arg_name, arg) in fn_arg_names.iter().copied().zip(function.get_param_iter()) {
        let alloca = self.function_arg_alloca(arg_name, arg);
        self.insert_param(arg_name, alloca);
    }
    self.insert_constants(model);
    let ptr = self.get_param(name);
    let tensor_size_value = self.int_type.const_int(tensor_size, false);
    self.builder
        .build_store(*self.get_param("tensor_data"), ptr.as_basic_value_enum())?;
    self.builder
        .build_store(*self.get_param("tensor_size"), tensor_size_value)?;
    self.builder.build_return(None)?;
    if function.verify(true) {
        Ok(function)
    } else {
        function.print_to_stderr();
        unsafe {
            function.delete();
        }
        Err(anyhow!("Invalid generated function."))
    }
}
/// Emits `get_inputs` (when `is_get`) or `set_inputs`, which copy the model's input tensor between
/// the caller's `inputs` buffer and the input tensor's storage.
///
/// There are two variants:
/// - `get_inputs(inputs, data)` copies from the input tensor into `inputs`.
/// - `set_inputs(inputs, data, model_index)` first stores `model_index` into the module's
///   `model_index` global, then copies from `inputs` into the input tensor.
///
/// No copy is emitted if the model has no input tensor, or if the input tensor has no stored
/// elements.
///
/// # Errors
/// Propagates IR builder failures. If LLVM verification fails, the function is printed to stderr,
/// deleted, and an error is returned.
pub fn compile_inputs(
    &mut self,
    model: &DiscreteModel,
    is_get: bool,
) -> Result<FunctionValue<'ctx>> {
    self.clear();
    let function_name = if is_get { "get_inputs" } else { "set_inputs" };
    let fn_arg_names: &[&str] = if is_get {
        &["inputs", "data"]
    } else {
        &["inputs", "data", "model_index"]
    };
    let fn_arg_types: &[BasicMetadataTypeEnum<'ctx>] = if is_get {
        &[self.real_ptr_type.into(), self.real_ptr_type.into()]
    } else {
        &[
            self.real_ptr_type.into(),
            self.real_ptr_type.into(),
            self.int_type.into(),
        ]
    };
    let function = self.add_function(function_name, fn_arg_names, fn_arg_types, None, false);
    let block = self.start_function(function, None);
    for (i, arg) in function.get_param_iter().enumerate() {
        let name = fn_arg_names[i];
        let alloca = self.function_arg_alloca(name, arg);
        self.insert_param(name, alloca);
    }
    if !is_get {
        let model_index = self
            .build_load(self.int_type, *self.get_param("model_index"), "model_index")?
            .into_int_value();
        self.builder
            .build_store(self.globals.model_index.as_pointer_value(), model_index)?;
    }
    if let Some(input) = model.input() {
        let name = input.name();
        self.insert_tensor(input, false);
        let input_len = input.nnz();
        // The copy loop below is bottom-tested: its body always runs at least once. An empty
        // input must therefore skip it, or element 0 of a zero-length tensor would be accessed.
        if input_len > 0 {
            let ptr = self.get_var(input);
            // loop thru the elements of this input and set/get them using the inputs ptr
            let start_index = self.int_type.const_int(0, false);
            let end_index = self.int_type.const_int(input_len.try_into().unwrap(), false);
            let input_block = self.context.append_basic_block(function, name);
            self.builder.build_unconditional_branch(input_block)?;
            self.builder.position_at_end(input_block);
            let index = self.builder.build_phi(self.int_type, "i")?;
            index.add_incoming(&[(&start_index, block)]);
            // loop body - the same index addresses both the input tensor and the inputs buffer
            let curr_input_index = index.as_basic_value().into_int_value();
            let input_ptr = Self::get_ptr_to_index(
                &self.builder,
                self.real_type,
                ptr,
                curr_input_index,
                name,
            );
            let inputs_ptr = Self::get_ptr_to_index(
                &self.builder,
                self.real_type,
                self.get_param("inputs"),
                curr_input_index,
                name,
            );
            if is_get {
                let input_value = self
                    .build_load(self.real_type, input_ptr, name)?
                    .into_float_value();
                self.builder.build_store(inputs_ptr, input_value)?;
            } else {
                let input_value = self
                    .build_load(self.real_type, inputs_ptr, name)?
                    .into_float_value();
                self.builder.build_store(input_ptr, input_value)?;
            }
            // increment loop index
            let one = self.int_type.const_int(1, false);
            let next_index = self.builder.build_int_add(curr_input_index, one, name)?;
            index.add_incoming(&[(&next_index, input_block)]);
            // loop condition
            let loop_while =
                self.builder
                    .build_int_compare(IntPredicate::ULT, next_index, end_index, name)?;
            let post_block = self.context.append_basic_block(function, name);
            self.builder
                .build_conditional_branch(loop_while, input_block, post_block)?;
            self.builder.position_at_end(post_block);
        }
    }
    self.builder.build_return(None)?;
    if function.verify(true) {
        Ok(function)
    } else {
        function.print_to_stderr();
        unsafe {
            function.delete();
        }
        Err(anyhow!("Invalid generated function."))
    }
}
/// Emits `set_id(id)`, which fills the `id` buffer with one value per state element.
///
/// Elements are written block by block, in the order of `model.state().elmts()`. Elements of a
/// block flagged algebraic in `model.is_algebraic()` get `0.0`; all other elements get `1.0`.
/// Blocks with no elements are skipped.
///
/// # Errors
/// Propagates IR builder failures. If LLVM verification fails, the function is printed to stderr,
/// deleted, and an error is returned.
pub fn compile_set_id(&mut self, model: &DiscreteModel) -> Result<FunctionValue<'ctx>> {
    self.clear();
    let fn_arg_names = &["id"];
    let function = self.add_function(
        "set_id",
        fn_arg_names,
        &[self.real_ptr_type.into()],
        None,
        false,
    );
    let mut block = self.start_function(function, None);
    for (i, arg) in function.get_param_iter().enumerate() {
        let name = fn_arg_names[i];
        let alloca = self.function_arg_alloca(name, arg);
        self.insert_param(name, alloca);
    }
    let mut id_index = 0usize;
    for (blk, is_algebraic) in zip(model.state().elmts(), model.is_algebraic()) {
        // The loop emitted below is bottom-tested: its body always runs at least once. An empty
        // block would otherwise write `id[id_index]`, which belongs to the next block (or lies
        // past the end of `id`).
        if blk.nnz() == 0 {
            continue;
        }
        let name = blk.name().unwrap_or("unknown");
        // loop thru the elements of this state blk and set the corresponding elements of id
        let id_start_index = self.int_type.const_int(id_index as u64, false);
        let blk_start_index = self.int_type.const_int(0, false);
        let blk_end_index = self
            .int_type
            .const_int(blk.nnz().try_into().unwrap(), false);
        let blk_block = self.context.append_basic_block(function, name);
        self.builder.build_unconditional_branch(blk_block)?;
        self.builder.position_at_end(blk_block);
        let index = self.builder.build_phi(self.int_type, "i")?;
        index.add_incoming(&[(&blk_start_index, block)]);
        // loop body - write this block's flag into id[id_index + i]
        let curr_blk_index = index.as_basic_value().into_int_value();
        let curr_id_index = self
            .builder
            .build_int_add(id_start_index, curr_blk_index, name)?;
        let id_ptr = Self::get_ptr_to_index(
            &self.builder,
            self.real_type,
            self.get_param("id"),
            curr_id_index,
            name,
        );
        let is_algebraic_float = if *is_algebraic { 0.0 } else { 1.0 };
        let is_algebraic_value = self.real_type.const_float(is_algebraic_float);
        self.builder.build_store(id_ptr, is_algebraic_value)?;
        // increment loop index
        let one = self.int_type.const_int(1, false);
        let next_index = self.builder.build_int_add(curr_blk_index, one, name)?;
        index.add_incoming(&[(&next_index, blk_block)]);
        // loop condition
        let loop_while = self.builder.build_int_compare(
            IntPredicate::ULT,
            next_index,
            blk_end_index,
            name,
        )?;
        let post_block = self.context.append_basic_block(function, name);
        self.builder
            .build_conditional_branch(loop_while, blk_block, post_block)?;
        self.builder.position_at_end(post_block);
        // get ready for next blk
        block = post_block;
        id_index += blk.nnz();
    }
    self.builder.build_return(None)?;
    if function.verify(true) {
        Ok(function)
    } else {
        function.print_to_stderr();
        unsafe {
            function.delete();
        }
        Err(anyhow!("Invalid generated function."))
    }
}
/// Returns a shared reference to this code generator's LLVM module (the `module` field).
pub fn module(&self) -> &Module<'ctx> {
    &self.module
}
}