diffsl/execution/llvm/codegen.rs

use aliasable::boxed::AliasableBox;
use anyhow::{anyhow, Result};
use inkwell::attributes::{Attribute, AttributeLoc};
use inkwell::basic_block::BasicBlock;
use inkwell::builder::Builder;
use inkwell::context::{AsContextRef, Context};
use inkwell::execution_engine::ExecutionEngine;
use inkwell::intrinsics::Intrinsic;
use inkwell::module::{Linkage, Module};
use inkwell::passes::PassBuilderOptions;
use inkwell::targets::{FileType, InitializationConfig, Target, TargetMachine, TargetTriple};
use inkwell::types::{
    BasicMetadataTypeEnum, BasicType, BasicTypeEnum, FloatType, FunctionType, IntType, PointerType,
};
use inkwell::values::{
    AsValueRef, BasicMetadataValueEnum, BasicValue, BasicValueEnum, CallSiteValue, FloatValue,
    FunctionValue, GlobalValue, IntValue, PointerValue,
};
use inkwell::{
    AddressSpace, AtomicOrdering, AtomicRMWBinOp, FloatPredicate, GlobalVisibility, IntPredicate,
    OptimizationLevel,
};
use llvm_sys::core::{
    LLVMBuildCall2, LLVMGetArgOperand, LLVMGetBasicBlockParent, LLVMGetGlobalParent,
    LLVMGetInstructionParent, LLVMGetNamedFunction, LLVMGlobalGetValueType, LLVMIsMultithreaded,
};
use llvm_sys::prelude::{LLVMBuilderRef, LLVMValueRef};
use std::collections::HashMap;
use std::ffi::CString;
use std::iter::zip;
use std::pin::Pin;
use target_lexicon::Triple;

use crate::ast::{Ast, AstKind};
use crate::discretise::{DiscreteModel, Tensor, TensorBlock};
use crate::enzyme::{
    CConcreteType_DT_Anything, CConcreteType_DT_Double, CConcreteType_DT_Float,
    CConcreteType_DT_Integer, CConcreteType_DT_Pointer, CDerivativeMode_DEM_ForwardMode,
    CDerivativeMode_DEM_ReverseModeCombined, CFnTypeInfo, CreateEnzymeLogic, CreateTypeAnalysis,
    DiffeGradientUtils, EnzymeCreateForwardDiff, EnzymeCreatePrimalAndGradient, EnzymeFreeTypeTree,
    EnzymeGradientUtilsNewFromOriginal, EnzymeLogicRef, EnzymeMergeTypeTree, EnzymeNewTypeTreeCT,
    EnzymeRegisterCallHandler, EnzymeTypeAnalysisRef, EnzymeTypeTreeOnlyEq, FreeEnzymeLogic,
    FreeTypeAnalysis, GradientUtils, IntList, LLVMOpaqueContext, CDIFFE_TYPE_DFT_CONSTANT,
    CDIFFE_TYPE_DFT_DUP_ARG, CDIFFE_TYPE_DFT_DUP_NONEED,
};
use crate::execution::compiler::CompilerMode;
use crate::execution::module::{
    CodegenModule, CodegenModuleCompile, CodegenModuleEmit, CodegenModuleJit,
};
use crate::execution::scalar::RealType;
use crate::execution::{DataLayout, Translation, TranslationFrom, TranslationTo};
use lazy_static::lazy_static;
use std::sync::Mutex;

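// Assumption: this mutex exists to serialise calls into Enzyme's global
// logic/type-analysis state, which is not known to be thread-safe.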
lazy_static! {
    static ref my_mutex: Mutex<i32> = Mutex::new(0i32);
}

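// A self-referential pair: `codegen` borrows from `context`, with the borrow
// erased to `'static` via `transmute` in `LlvmModule::new`. `AliasableBox`
// keeps the `Context` at a stable heap address and `PhantomPinned` prevents
// the struct from being moved, which is what keeps the erased lifetime sound.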
struct ImmovableLlvmModule {
    // actually has lifetime of `context`
    // declared first so it's dropped before `context`
    codegen: Option<CodeGen<'static>>,
    // safety: we must never move out of this box as long as codegen is alive
    context: AliasableBox<Context>,
    _pin: std::marker::PhantomPinned,
}

pub struct LlvmModule {
    inner: Pin<Box<ImmovableLlvmModule>>,
    machine: TargetMachine,
}

unsafe impl Send for LlvmModule {}
unsafe impl Sync for LlvmModule {}
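// Safety note (assumption): the LLVM `Context` is not itself thread-safe, so
// these impls presumably rely on callers not mutating the module from more
// than one thread at a time.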

impl LlvmModule {
    fn new(
        triple: Option<Triple>,
        model: &DiscreteModel,
        threaded: bool,
        real_type: RealType,
    ) -> Result<Self> {
        let initialization_config = &InitializationConfig::default();
        Target::initialize_all(initialization_config);
        let host_triple = Triple::host();
        let (triple_str, native) = match triple {
            Some(ref triple) => (triple.to_string(), false),
            None => (host_triple.to_string(), true),
        };
        let triple = TargetTriple::create(triple_str.as_str());
        let target = Target::from_triple(&triple).unwrap();
        let cpu = if native {
            TargetMachine::get_host_cpu_name().to_string()
        } else {
            "generic".to_string()
        };
        let features = if native {
            TargetMachine::get_host_cpu_features().to_string()
        } else {
            "".to_string()
        };
        let machine = target
            .create_target_machine(
                &triple,
                cpu.as_str(),
                features.as_str(),
                inkwell::OptimizationLevel::Aggressive,
                inkwell::targets::RelocMode::Default,
                inkwell::targets::CodeModel::Default,
            )
            .unwrap();

        let context = AliasableBox::from_unique(Box::new(Context::create()));
        let mut pinned = Self {
            inner: Box::pin(ImmovableLlvmModule {
                codegen: None,
                context,
                _pin: std::marker::PhantomPinned,
            }),
            machine,
        };

        let context_ref = pinned.inner.context.as_ref();
        let real_type_llvm = match real_type {
            RealType::F32 => context_ref.f32_type(),
            RealType::F64 => context_ref.f64_type(),
        };
        let codegen = CodeGen::new(
            model,
            context_ref,
            real_type,
            real_type_llvm,
            context_ref.i32_type(),
            threaded,
        )?;
        let codegen = unsafe { std::mem::transmute::<CodeGen<'_>, CodeGen<'static>>(codegen) };
        unsafe { pinned.inner.as_mut().get_unchecked_mut().codegen = Some(codegen) };
        Ok(pinned)
    }

    fn pre_autodiff_optimisation(&mut self) -> Result<()> {
        //let pass_manager_builder = PassManagerBuilder::create();
        //pass_manager_builder.set_optimization_level(inkwell::OptimizationLevel::Default);
        //let pass_manager = PassManager::create(());
        //pass_manager_builder.populate_module_pass_manager(&pass_manager);
        //pass_manager.run_on(self.codegen().module());

        //self.codegen().module().print_to_stderr();
        // optimise at -O2, with no unrolling, before handing the module to Enzyme
        let pass_options = PassBuilderOptions::create();
        //pass_options.set_verify_each(true);
        //pass_options.set_debug_logging(true);
        //pass_options.set_loop_interleaving(true);
        pass_options.set_loop_vectorization(false);
        pass_options.set_loop_slp_vectorization(false);
        pass_options.set_loop_unrolling(false);
        //pass_options.set_forget_all_scev_in_loop_unroll(true);
        //pass_options.set_licm_mssa_opt_cap(1);
        //pass_options.set_licm_mssa_no_acc_for_promotion_cap(10);
        //pass_options.set_call_graph_profile(true);
        //pass_options.set_merge_functions(true);

        //let passes = "default<O2>";
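        // The string below appears to be an expanded form of the `default<O2>`
        // pipeline with the loop-unroll and loop/SLP-vectorisation passes left
        // out (matching the options set above), so Enzyme sees simpler loops.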
        let passes = "annotation2metadata,forceattrs,inferattrs,coro-early,function<eager-inv>(lower-expect,simplifycfg<bonus-inst-threshold=1;no-forward-switch-cond;no-switch-range-to-icmp;no-switch-to-lookup;keep-loops;no-hoist-common-insts;no-sink-common-insts>,early-cse<>),openmp-opt,ipsccp,called-value-propagation,globalopt,function(mem2reg),function<eager-inv>(instcombine,simplifycfg<bonus-inst-threshold=1;no-forward-switch-cond;switch-range-to-icmp;no-switch-to-lookup;keep-loops;no-hoist-common-insts;no-sink-common-insts>),require<globals-aa>,function(invalidate<aa>),require<profile-summary>,cgscc(devirt<4>(inline<only-mandatory>,inline,function-attrs,openmp-opt-cgscc,function<eager-inv>(early-cse<memssa>,speculative-execution,jump-threading,correlated-propagation,simplifycfg<bonus-inst-threshold=1;no-forward-switch-cond;switch-range-to-icmp;no-switch-to-lookup;keep-loops;no-hoist-common-insts;no-sink-common-insts>,instcombine,libcalls-shrinkwrap,tailcallelim,simplifycfg<bonus-inst-threshold=1;no-forward-switch-cond;switch-range-to-icmp;no-switch-to-lookup;keep-loops;no-hoist-common-insts;no-sink-common-insts>,reassociate,require<opt-remark-emit>,loop-mssa(loop-instsimplify,loop-simplifycfg,licm<no-allowspeculation>,loop-rotate,licm<allowspeculation>,simple-loop-unswitch<no-nontrivial;trivial>),simplifycfg<bonus-inst-threshold=1;no-forward-switch-cond;switch-range-to-icmp;no-switch-to-lookup;keep-loops;no-hoist-common-insts;no-sink-common-insts>,instcombine,loop(loop-idiom,indvars,loop-deletion),vector-combine,mldst-motion<no-split-footer-bb>,gvn<>,sccp,bdce,instcombine,jump-threading,correlated-propagation,adce,memcpyopt,dse,loop-mssa(licm<allowspeculation>),coro-elide,simplifycfg<bonus-inst-threshold=1;no-forward-switch-cond;switch-range-to-icmp;no-switch-to-lookup;keep-loops;hoist-common-insts;sink-common-insts>,instcombine),coro-split)),deadargelim,coro-cleanup,globalopt,globaldce,elim-avail-extern,rpo-function-attrs,recompute-globalsaa,function<eager-inv>(float2int,lower-constant-intrinsics,loop(loop-rotate,loop-deletion),loop-distribute,inject-tli-mappings,loop-load-elim,instcombine,simplifycfg<bonus-inst-threshold=1;forward-switch-cond;switch-range-to-icmp;switch-to-lookup;no-keep-loops;hoist-common-insts;sink-common-insts>,vector-combine,instcombine,transform-warning,instcombine,require<opt-remark-emit>,loop-mssa(licm<allowspeculation>),alignment-from-assumptions,loop-sink,instsimplify,div-rem-pairs,tailcallelim,simplifycfg<bonus-inst-threshold=1;no-forward-switch-cond;switch-range-to-icmp;no-switch-to-lookup;keep-loops;no-hoist-common-insts;no-sink-common-insts>),globaldce,constmerge,cg-profile,rel-lookup-table-converter,function(annotation-remarks),verify";
        let (codegen, machine) = self.codegen_and_machine_mut();
        codegen
            .module()
            .run_passes(passes, machine, pass_options)
            .map_err(|e| anyhow!("Failed to run passes: {:?}", e))
    }

    fn post_autodiff_optimisation(&mut self) -> Result<()> {
        // remove the noinline attribute from the barrier function, as it is only needed for Enzyme
        if let Some(barrier_func) = self.codegen_mut().module().get_function("barrier") {
            let noinline_kind_id = Attribute::get_named_enum_kind_id("noinline");
            barrier_func.remove_enum_attribute(AttributeLoc::Function, noinline_kind_id);
        }

        // remove all preprocess_* functions
        for f in self.codegen_mut().module.get_functions() {
            if f.get_name().to_str().unwrap().starts_with("preprocess_") {
                unsafe { f.delete() };
            }
        }

        //self.codegen()
        //    .module()
        //    .print_to_file("post_autodiff_optimisation.ll")
        //    .unwrap();

        let passes = "default<O3>";
        let (codegen, machine) = self.codegen_and_machine_mut();
        codegen
            .module()
            .run_passes(passes, machine, PassBuilderOptions::create())
            .map_err(|e| anyhow!("Failed to run passes: {:?}", e))?;

        Ok(())
    }

    pub fn print(&self) {
        self.codegen().module().print_to_stderr();
    }
    fn codegen_mut(&mut self) -> &mut CodeGen<'static> {
        unsafe {
            self.inner
                .as_mut()
                .get_unchecked_mut()
                .codegen
                .as_mut()
                .unwrap()
        }
    }
    fn codegen_and_machine_mut(&mut self) -> (&mut CodeGen<'static>, &TargetMachine) {
        (
            unsafe {
                self.inner
                    .as_mut()
                    .get_unchecked_mut()
                    .codegen
                    .as_mut()
                    .unwrap()
            },
            &self.machine,
        )
    }

    fn codegen(&self) -> &CodeGen<'static> {
        self.inner.as_ref().get_ref().codegen.as_ref().unwrap()
    }
}

impl CodegenModule for LlvmModule {}

impl CodegenModuleCompile for LlvmModule {
    fn from_discrete_model(
        model: &DiscreteModel,
        mode: CompilerMode,
        triple: Option<Triple>,
        real_type: RealType,
    ) -> Result<Self> {
        let thread_dim = mode.thread_dim(model.state().nnz());
        let threaded = thread_dim > 1;
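        // `LLVMIsMultithreaded` reports whether LLVM itself was built with
        // thread support (LLVM_ENABLE_THREADS); bail out early if it was not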
        if unsafe { LLVMIsMultithreaded() } <= 0 {
            return Err(anyhow!(
                "LLVM is not compiled with multithreading support, but this codegen module requires it."
            ));
        }

        let mut module = Self::new(triple, model, threaded, real_type)?;

        let set_u0 = module.codegen_mut().compile_set_u0(model)?;
        let _calc_stop = module.codegen_mut().compile_calc_stop(model)?;
        let rhs = module.codegen_mut().compile_rhs(model, false)?;
        let rhs_full = module.codegen_mut().compile_rhs(model, true)?;
        let mass = module.codegen_mut().compile_mass(model)?;
        let calc_out = module.codegen_mut().compile_calc_out(model, false)?;
        let calc_out_full = module.codegen_mut().compile_calc_out(model, true)?;
        let _set_id = module.codegen_mut().compile_set_id(model)?;
        let _get_dims = module.codegen_mut().compile_get_dims(model)?;
        let set_inputs = module.codegen_mut().compile_inputs(model, false)?;
        let _get_inputs = module.codegen_mut().compile_inputs(model, true)?;
        let _set_constants = module.codegen_mut().compile_set_constants(model)?;
        let tensor_info = module
            .codegen()
            .layout
            .tensors()
            .map(|(name, is_constant)| (name.to_string(), is_constant))
            .collect::<Vec<_>>();
        for (tensor, is_constant) in tensor_info {
            if is_constant {
                module
                    .codegen_mut()
                    .compile_get_constant(model, tensor.as_str())?;
            } else {
                module
                    .codegen_mut()
                    .compile_get_tensor(model, tensor.as_str())?;
            }
        }

        module.pre_autodiff_optimisation()?;

        module.codegen_mut().compile_gradient(
            set_u0,
            &[
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::Const,
                CompileGradientArgType::Const,
            ],
            CompileMode::Forward,
            "set_u0_grad",
        )?;

        module.codegen_mut().compile_gradient(
            rhs,
            &[
                CompileGradientArgType::Const,
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::Const,
                CompileGradientArgType::Const,
            ],
            CompileMode::Forward,
            "rhs_grad",
        )?;

        module.codegen_mut().compile_gradient(
            calc_out,
            &[
                CompileGradientArgType::Const,
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::Const,
                CompileGradientArgType::Const,
            ],
            CompileMode::Forward,
            "calc_out_grad",
        )?;
        module.codegen_mut().compile_gradient(
            set_inputs,
            &[
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::DupNoNeed,
            ],
            CompileMode::Forward,
            "set_inputs_grad",
        )?;

        module.codegen_mut().compile_gradient(
            set_u0,
            &[
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::Const,
                CompileGradientArgType::Const,
            ],
            CompileMode::Reverse,
            "set_u0_rgrad",
        )?;

        module.codegen_mut().compile_gradient(
            mass,
            &[
                CompileGradientArgType::Const,
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::Const,
                CompileGradientArgType::Const,
            ],
            CompileMode::Reverse,
            "mass_rgrad",
        )?;

        module.codegen_mut().compile_gradient(
            rhs,
            &[
                CompileGradientArgType::Const,
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::Const,
                CompileGradientArgType::Const,
            ],
            CompileMode::Reverse,
            "rhs_rgrad",
        )?;
        module.codegen_mut().compile_gradient(
            calc_out,
            &[
                CompileGradientArgType::Const,
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::Const,
                CompileGradientArgType::Const,
            ],
            CompileMode::Reverse,
            "calc_out_rgrad",
        )?;

        module.codegen_mut().compile_gradient(
            set_inputs,
            &[
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::DupNoNeed,
            ],
            CompileMode::Reverse,
            "set_inputs_rgrad",
        )?;

        module.codegen_mut().compile_gradient(
            rhs_full,
            &[
                CompileGradientArgType::Const,
                CompileGradientArgType::Const,
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::Const,
                CompileGradientArgType::Const,
            ],
            CompileMode::ForwardSens,
            "rhs_sgrad",
        )?;

        module.codegen_mut().compile_gradient(
            set_u0,
            &[
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::Const,
                CompileGradientArgType::Const,
            ],
            CompileMode::ForwardSens,
            "set_u0_sgrad",
        )?;

        module.codegen_mut().compile_gradient(
            calc_out_full,
            &[
                CompileGradientArgType::Const,
                CompileGradientArgType::Const,
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::Const,
                CompileGradientArgType::Const,
            ],
            CompileMode::ForwardSens,
            "calc_out_sgrad",
        )?;
        module.codegen_mut().compile_gradient(
            calc_out_full,
            &[
                CompileGradientArgType::Const,
                CompileGradientArgType::Const,
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::Const,
                CompileGradientArgType::Const,
            ],
            CompileMode::ReverseSens,
            "calc_out_srgrad",
        )?;

        module.codegen_mut().compile_gradient(
            rhs_full,
            &[
                CompileGradientArgType::Const,
                CompileGradientArgType::Const,
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::Const,
                CompileGradientArgType::Const,
            ],
            CompileMode::ReverseSens,
            "rhs_srgrad",
        )?;

        module.post_autodiff_optimisation()?;
        Ok(module)
    }
}

impl CodegenModuleEmit for LlvmModule {
    fn to_object(self) -> Result<Vec<u8>> {
        let module = self.codegen().module();
        //module.print_to_stderr();
        let buffer = self
            .machine
            .write_to_memory_buffer(module, FileType::Object)
            .unwrap()
            .as_slice()
            .to_vec();
        Ok(buffer)
    }
}
impl CodegenModuleJit for LlvmModule {
    fn jit(&mut self) -> Result<HashMap<String, *const u8>> {
        let ee = self
            .codegen()
            .module()
            .create_jit_execution_engine(OptimizationLevel::Aggressive)
            .map_err(|e| anyhow!("Failed to create JIT execution engine: {:?}", e))?;

        let module = self.codegen().module();
        let mut symbols = HashMap::new();
        for function in module.get_functions() {
            let name = function.get_name().to_str().unwrap();
            let address = ee.get_function_address(name);
            if let Ok(address) = address {
                symbols.insert(name.to_string(), address as *const u8);
            }
        }
        Ok(symbols)
    }
}

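// Module-level globals. Note the `enzyme_const_` prefix on the names created
// below: Enzyme treats globals whose names contain `enzyme_const` as inactive
// (see the todo comment on `thread_counter`), keeping them out of derivatives.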
struct Globals<'ctx> {
    indices: Option<GlobalValue<'ctx>>,
    constants: Option<GlobalValue<'ctx>>,
    thread_counter: Option<GlobalValue<'ctx>>,
}

impl<'ctx> Globals<'ctx> {
    fn new(
        layout: &DataLayout,
        module: &Module<'ctx>,
        int_type: IntType<'ctx>,
        real_type: FloatType<'ctx>,
        threaded: bool,
    ) -> Self {
        let thread_counter = if threaded {
            let tc = module.add_global(
                int_type,
                Some(AddressSpace::default()),
                "enzyme_const_thread_counter",
            );
            // todo: for some reason this doesn't make enzyme think it's inactive
            // but using enzyme_const in the name does
            // todo: also, adding this metadata causes the print of the module to segfault,
            // so maybe a bug in inkwell
            //let md_string = context.metadata_string("enzyme_inactive");
            //tc.set_metadata(md_string, 0);
            let tc_value = int_type.const_zero();
            tc.set_visibility(GlobalVisibility::Hidden);
            tc.set_initializer(&tc_value.as_basic_value_enum());
            Some(tc)
        } else {
            None
        };
        let constants = if layout.constants().is_empty() {
            None
        } else {
            let constants_array_type =
                real_type.array_type(u32::try_from(layout.constants().len()).unwrap());
            let constants = module.add_global(
                constants_array_type,
                Some(AddressSpace::default()),
                "enzyme_const_constants",
            );
            constants.set_visibility(GlobalVisibility::Hidden);
            constants.set_constant(false);
            constants.set_initializer(&constants_array_type.const_zero());
            Some(constants)
        };
        let indices = if layout.indices().is_empty() {
            None
        } else {
            let indices_array_type =
                int_type.array_type(u32::try_from(layout.indices().len()).unwrap());
            let indices_array_values = layout
                .indices()
                .iter()
                .map(|&i| int_type.const_int(i as u64, true))
                .collect::<Vec<IntValue>>();
            let indices_value = int_type.const_array(indices_array_values.as_slice());
            let indices = module.add_global(
                indices_array_type,
                Some(AddressSpace::default()),
                "enzyme_const_indices",
            );
            indices.set_constant(true);
            indices.set_visibility(GlobalVisibility::Hidden);
            indices.set_initializer(&indices_value);
            Some(indices)
        };
        Self {
            indices,
            thread_counter,
            constants,
        }
    }
}

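// Activity of each argument handed to Enzyme: `Const` corresponds to
// CDIFFE_TYPE_DFT_CONSTANT (no derivative), `Dup` to CDIFFE_TYPE_DFT_DUP_ARG
// (shadow passed alongside the primal), and `DupNoNeed` to
// CDIFFE_TYPE_DFT_DUP_NONEED (shadow only; the primal result is not needed).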
pub enum CompileGradientArgType {
    Const,
    Dup,
    DupNoNeed,
}

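// Differentiation mode for `compile_gradient`: forward-mode AD for
// `Forward`/`ForwardSens` and reverse-mode for `Reverse`/`ReverseSens`; the
// `Sens` variants produce the sensitivity wrappers (`*_sgrad`, `*_srgrad`)
// generated in `from_discrete_model` above.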
pub enum CompileMode {
    Forward,
    ForwardSens,
    Reverse,
    ReverseSens,
}

pub struct CodeGen<'ctx> {
    context: &'ctx inkwell::context::Context,
    module: Module<'ctx>,
    builder: Builder<'ctx>,
    variables: HashMap<String, PointerValue<'ctx>>,
    functions: HashMap<String, FunctionValue<'ctx>>,
    fn_value_opt: Option<FunctionValue<'ctx>>,
    tensor_ptr_opt: Option<PointerValue<'ctx>>,
    diffsl_real_type: RealType,
    real_type: FloatType<'ctx>,
    real_ptr_type: PointerType<'ctx>,
    int_type: IntType<'ctx>,
    int_ptr_type: PointerType<'ctx>,
    layout: DataLayout,
    globals: Globals<'ctx>,
    threaded: bool,
    _ee: Option<ExecutionEngine<'ctx>>,
}

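// Enzyme custom call handlers for the `barrier` function, so Enzyme does not
// try to differentiate through the spin loop. The forward handler returns 1
// (presumably: keep the primal call unchanged); the reverse handler emits a
// call to `barrier_grad`, which replays the barrier with the mirrored
// numbering the reverse pass expects (see `compile_barrier_grad`).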
unsafe extern "C" fn fwd_handler(
    _builder: LLVMBuilderRef,
    _call_instruction: LLVMValueRef,
    _gutils: *mut GradientUtils,
    _dcall: *mut LLVMValueRef,
    _normal_return: *mut LLVMValueRef,
    _shadow_return: *mut LLVMValueRef,
) -> u8 {
    1
}

unsafe extern "C" fn rev_handler(
    builder: LLVMBuilderRef,
    call_instruction: LLVMValueRef,
    gutils: *mut DiffeGradientUtils,
    _tape: LLVMValueRef,
) {
    let call_block = LLVMGetInstructionParent(call_instruction);
    let call_function = LLVMGetBasicBlockParent(call_block);
    let module = LLVMGetGlobalParent(call_function);
    let name_c_str = CString::new("barrier_grad").unwrap();
    let barrier_func = LLVMGetNamedFunction(module, name_c_str.as_ptr());
    let barrier_func_type = LLVMGlobalGetValueType(barrier_func);
    let barrier_num = LLVMGetArgOperand(call_instruction, 0);
    let total_barriers = LLVMGetArgOperand(call_instruction, 1);
    let thread_count = LLVMGetArgOperand(call_instruction, 2);
    let barrier_num = EnzymeGradientUtilsNewFromOriginal(gutils as *mut GradientUtils, barrier_num);
    let total_barriers =
        EnzymeGradientUtilsNewFromOriginal(gutils as *mut GradientUtils, total_barriers);
    let thread_count =
        EnzymeGradientUtilsNewFromOriginal(gutils as *mut GradientUtils, thread_count);
    let mut args = [barrier_num, total_barriers, thread_count];
    let name_c_str = CString::new("").unwrap();
    LLVMBuildCall2(
        builder,
        barrier_func_type,
        barrier_func,
        args.as_mut_ptr(),
        args.len() as u32,
        name_c_str.as_ptr(),
    );
}

#[allow(dead_code)]
enum PrintValue<'ctx> {
    Real(FloatValue<'ctx>),
    Int(IntValue<'ctx>),
}

impl<'ctx> CodeGen<'ctx> {
    pub fn new(
        model: &DiscreteModel,
        context: &'ctx inkwell::context::Context,
        diffsl_real_type: RealType,
        real_type: FloatType<'ctx>,
        int_type: IntType<'ctx>,
        threaded: bool,
    ) -> Result<Self> {
        let builder = context.create_builder();
        let layout = DataLayout::new(model);
        let module = context.create_module(model.name());
        let globals = Globals::new(&layout, &module, int_type, real_type, threaded);
        let real_ptr_type = Self::pointer_type(context, real_type.into());
        let int_ptr_type = Self::pointer_type(context, int_type.into());
        let mut ret = Self {
            context,
            module,
            builder,
            real_type,
            real_ptr_type,
            variables: HashMap::new(),
            functions: HashMap::new(),
            fn_value_opt: None,
            tensor_ptr_opt: None,
            layout,
            diffsl_real_type,
            int_type,
            int_ptr_type,
            globals,
            threaded,
            _ee: None,
        };
        if threaded {
            ret.compile_barrier_init()?;
            ret.compile_barrier()?;
            ret.compile_barrier_grad()?;
            // todo: think I can remove this unless I want to call enzyme using an LLVM pass
            //ret.globals.add_registered_barrier(ret.context, &ret.module);
        }
        Ok(ret)
    }

    #[allow(dead_code)]
    fn compile_print_value(
        &mut self,
        name: &str,
        value: PrintValue<'ctx>,
    ) -> Result<CallSiteValue<'_>> {
        let void_type = self.context.void_type();
        // int printf(const char *format, ...)
        let printf_type = void_type.fn_type(&[self.int_ptr_type.into()], true);
        // get printf function or declare it if it doesn't exist
        let printf = match self.module.get_function("printf") {
            Some(f) => f,
            None => self
                .module
                .add_function("printf", printf_type, Some(Linkage::External)),
        };
        let (format_str, format_str_name) = match value {
            PrintValue::Real(_) => (format!("{name}: %f\n"), format!("real_format_{name}")),
            PrintValue::Int(_) => (format!("{name}: %d\n"), format!("int_format_{name}")),
        };
        // change format_str to c string
        let format_str = CString::new(format_str).unwrap();
        // if format_str_name does not already exist as a global, add it
        let format_str_global = match self.module.get_global(format_str_name.as_str()) {
            Some(g) => g,
            None => {
                let format_str = self.context.const_string(format_str.as_bytes(), true);
                let fmt_str =
                    self.module
                        .add_global(format_str.get_type(), None, format_str_name.as_str());
                fmt_str.set_initializer(&format_str);
                fmt_str.set_visibility(GlobalVisibility::Hidden);
                fmt_str
            }
        };
        // call printf with the format string and the value
        let format_str_ptr = self.builder.build_pointer_cast(
            format_str_global.as_pointer_value(),
            self.int_ptr_type,
            "format_str_ptr",
        )?;
        let value: BasicMetadataValueEnum = match value {
            PrintValue::Real(v) => v.into(),
            PrintValue::Int(v) => v.into(),
        };
        self.builder
            .build_call(printf, &[format_str_ptr.into(), value], "printf_call")
            .map_err(|e| anyhow!("Error building call to printf: {}", e))
    }

    fn compile_set_constants(&mut self, model: &DiscreteModel) -> Result<FunctionValue<'ctx>> {
        self.clear();
        let void_type = self.context.void_type();
        let fn_type = void_type.fn_type(&[self.int_type.into(), self.int_type.into()], false);
        let fn_arg_names = &["thread_id", "thread_dim"];
        let function = self.module.add_function("set_constants", fn_type, None);

        let basic_block = self.context.append_basic_block(function, "entry");
        self.fn_value_opt = Some(function);
        self.builder.position_at_end(basic_block);

        for (i, arg) in function.get_param_iter().enumerate() {
            let name = fn_arg_names[i];
            let alloca = self.function_arg_alloca(name, arg);
            self.insert_param(name, alloca);
        }

        self.insert_indices();
        self.insert_constants(model);

        let mut nbarriers = 0;
        let total_barriers = (model.constant_defns().len()) as u64;
        #[allow(clippy::explicit_counter_loop)]
        for a in model.constant_defns() {
            self.jit_compile_tensor(a, Some(*self.get_var(a)))?;
            self.jit_compile_call_barrier(nbarriers, total_barriers);
            nbarriers += 1;
        }

        self.builder.build_return(None)?;

        if function.verify(true) {
            Ok(function)
        } else {
            function.print_to_stderr();
            unsafe {
                function.delete();
            }
            Err(anyhow!("Invalid generated function."))
        }
    }

    fn compile_barrier_init(&mut self) -> Result<FunctionValue<'ctx>> {
        self.clear();
        let void_type = self.context.void_type();
        let fn_type = void_type.fn_type(&[], false);
        let function = self.module.add_function("barrier_init", fn_type, None);

        let entry_block = self.context.append_basic_block(function, "entry");

        self.fn_value_opt = Some(function);
        self.builder.position_at_end(entry_block);

        let thread_counter = self.globals.thread_counter.unwrap().as_pointer_value();
        self.builder
            .build_store(thread_counter, self.int_type.const_zero())?;

        self.builder.build_return(None)?;

        if function.verify(true) {
            self.functions.insert("barrier_init".to_owned(), function);
            Ok(function)
        } else {
            function.print_to_stderr();
            unsafe {
                function.delete();
            }
            Err(anyhow!("Invalid generated function."))
        }
    }

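    // A counting barrier: each of `thread_count` threads increments the shared
    // counter once per barrier, so after barrier `n` every thread has pushed
    // the counter to `n * thread_count`. A thread therefore spins until
    // `counter >= barrier_num * thread_count` before continuing. For example,
    // with 4 threads, barrier 2 is complete once the counter reaches 8.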
    fn compile_barrier(&mut self) -> Result<FunctionValue<'ctx>> {
        self.clear();
        let void_type = self.context.void_type();
        let fn_type = void_type.fn_type(
            &[
                self.int_type.into(),
                self.int_type.into(),
                self.int_type.into(),
            ],
            false,
        );
        let function = self.module.add_function("barrier", fn_type, None);
        let noinline_kind_id = Attribute::get_named_enum_kind_id("noinline");
        let noinline = self.context.create_enum_attribute(noinline_kind_id, 0);
        function.add_attribute(AttributeLoc::Function, noinline);

        let entry_block = self.context.append_basic_block(function, "entry");
        let increment_block = self.context.append_basic_block(function, "increment");
        let wait_loop_block = self.context.append_basic_block(function, "wait_loop");
        let barrier_done_block = self.context.append_basic_block(function, "barrier_done");

        self.fn_value_opt = Some(function);
        self.builder.position_at_end(entry_block);

        let thread_counter = self.globals.thread_counter.unwrap().as_pointer_value();
        let barrier_num = function.get_nth_param(0).unwrap().into_int_value();
        let total_barriers = function.get_nth_param(1).unwrap().into_int_value();
        let thread_count = function.get_nth_param(2).unwrap().into_int_value();

        let nbarrier_equals_total_barriers = self
            .builder
            .build_int_compare(
                IntPredicate::EQ,
                barrier_num,
                total_barriers,
                "nbarrier_equals_total_barriers",
            )
            .unwrap();
        // branch to barrier_done if nbarrier == total_barriers
        self.builder.build_conditional_branch(
            nbarrier_equals_total_barriers,
            barrier_done_block,
            increment_block,
        )?;
        self.builder.position_at_end(increment_block);

        let barrier_num_times_thread_count = self
            .builder
            .build_int_mul(barrier_num, thread_count, "barrier_num_times_thread_count")
            .unwrap();

        // Atomically increment the barrier counter
        let i32_type = self.context.i32_type();
        let one = i32_type.const_int(1, false);
        self.builder.build_atomicrmw(
            AtomicRMWBinOp::Add,
            thread_counter,
            one,
            AtomicOrdering::Monotonic,
        )?;

        // wait_loop:
        self.builder.build_unconditional_branch(wait_loop_block)?;
        self.builder.position_at_end(wait_loop_block);

        let current_value = self
            .builder
            .build_load(i32_type, thread_counter, "current_value")?
            .into_int_value();

        current_value
            .as_instruction_value()
            .unwrap()
            .set_atomic_ordering(AtomicOrdering::Monotonic)
            .map_err(|e| anyhow!("Error setting atomic ordering: {:?}", e))?;

        let all_threads_done = self.builder.build_int_compare(
            IntPredicate::UGE,
            current_value,
            barrier_num_times_thread_count,
            "all_threads_done",
        )?;

        self.builder.build_conditional_branch(
            all_threads_done,
            barrier_done_block,
            wait_loop_block,
        )?;
        self.builder.position_at_end(barrier_done_block);

        self.builder.build_return(None)?;

        if function.verify(true) {
            self.functions.insert("barrier".to_owned(), function);
            Ok(function)
        } else {
            function.print_to_stderr();
            unsafe {
                function.delete();
            }
            Err(anyhow!("Invalid generated function."))
        }
    }

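    // Reverse-pass barrier: the forward pass performs `total_barriers`
    // increments per thread, and the reverse pass replays the barriers in
    // reverse order, so the gradient of barrier `n` waits until the counter
    // reaches `(2 * total_barriers - n) * thread_count`.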
    fn compile_barrier_grad(&mut self) -> Result<FunctionValue<'ctx>> {
        self.clear();
        let void_type = self.context.void_type();
        let fn_type = void_type.fn_type(
            &[
                self.int_type.into(),
                self.int_type.into(),
                self.int_type.into(),
            ],
            false,
        );
        let function = self.module.add_function("barrier_grad", fn_type, None);

        let entry_block = self.context.append_basic_block(function, "entry");
        let wait_loop_block = self.context.append_basic_block(function, "wait_loop");
        let barrier_done_block = self.context.append_basic_block(function, "barrier_done");

        self.fn_value_opt = Some(function);
        self.builder.position_at_end(entry_block);

        let thread_counter = self.globals.thread_counter.unwrap().as_pointer_value();
        let barrier_num = function.get_nth_param(0).unwrap().into_int_value();
        let total_barriers = function.get_nth_param(1).unwrap().into_int_value();
        let thread_count = function.get_nth_param(2).unwrap().into_int_value();

        let twice_total_barriers = self
            .builder
            .build_int_mul(
                total_barriers,
                self.int_type.const_int(2, false),
                "twice_total_barriers",
            )
            .unwrap();
        let twice_total_barriers_minus_barrier_num = self
            .builder
            .build_int_sub(
                twice_total_barriers,
                barrier_num,
                "twice_total_barriers_minus_barrier_num",
            )
            .unwrap();
        let twice_total_barriers_minus_barrier_num_times_thread_count = self
            .builder
            .build_int_mul(
                twice_total_barriers_minus_barrier_num,
                thread_count,
                "twice_total_barriers_minus_barrier_num_times_thread_count",
            )
            .unwrap();

        // Atomically increment the barrier counter
        let i32_type = self.context.i32_type();
        let one = i32_type.const_int(1, false);
        self.builder.build_atomicrmw(
            AtomicRMWBinOp::Add,
            thread_counter,
            one,
            AtomicOrdering::Monotonic,
        )?;

        // wait_loop:
        self.builder.build_unconditional_branch(wait_loop_block)?;
        self.builder.position_at_end(wait_loop_block);

        let current_value = self
            .builder
            .build_load(i32_type, thread_counter, "current_value")?
            .into_int_value();
        current_value
            .as_instruction_value()
            .unwrap()
            .set_atomic_ordering(AtomicOrdering::Monotonic)
            .map_err(|e| anyhow!("Error setting atomic ordering: {:?}", e))?;

        let all_threads_done = self.builder.build_int_compare(
            IntPredicate::UGE,
            current_value,
            twice_total_barriers_minus_barrier_num_times_thread_count,
            "all_threads_done",
        )?;

        self.builder.build_conditional_branch(
            all_threads_done,
            barrier_done_block,
            wait_loop_block,
        )?;
        self.builder.position_at_end(barrier_done_block);

        self.builder.build_return(None)?;

        if function.verify(true) {
            self.functions.insert("barrier_grad".to_owned(), function);
            Ok(function)
        } else {
            function.print_to_stderr();
            unsafe {
                function.delete();
            }
            Err(anyhow!("Invalid generated function."))
        }
    }

    fn jit_compile_call_barrier(&mut self, nbarrier: u64, total_barriers: u64) {
        if !self.threaded {
            return;
        }
        let thread_dim = self.get_param("thread_dim");
        let thread_dim = self
            .builder
            .build_load(self.int_type, *thread_dim, "thread_dim")
            .unwrap()
            .into_int_value();
        let nbarrier = self.int_type.const_int(nbarrier + 1, false);
        let total_barriers = self.int_type.const_int(total_barriers, false);
        let barrier = self.get_function("barrier").unwrap();
        self.builder
            .build_call(
                barrier,
                &[
                    BasicMetadataValueEnum::IntValue(nbarrier),
                    BasicMetadataValueEnum::IntValue(total_barriers),
                    BasicMetadataValueEnum::IntValue(thread_dim),
                ],
                "barrier",
            )
            .unwrap();
    }

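    // Splits `size` items across `thread_dim` threads: thread `i` handles the
    // half-open range [i*size/dim, (i+1)*size/dim), which covers all indices
    // without gaps and spreads any remainder evenly. For example, size = 10
    // and thread_dim = 4 gives the ranges [0,2), [2,5), [5,7), [7,10).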
    fn jit_threading_limits(
        &mut self,
        size: IntValue<'ctx>,
    ) -> Result<(
        IntValue<'ctx>,
        IntValue<'ctx>,
        BasicBlock<'ctx>,
        BasicBlock<'ctx>,
    )> {
        let one = self.int_type.const_int(1, false);
        let thread_id = self.get_param("thread_id");
        let thread_id = self
            .builder
            .build_load(self.int_type, *thread_id, "thread_id")
            .unwrap()
            .into_int_value();
        let thread_dim = self.get_param("thread_dim");
        let thread_dim = self
            .builder
            .build_load(self.int_type, *thread_dim, "thread_dim")
            .unwrap()
            .into_int_value();

        // start index is i * size / thread_dim
        let i_times_size = self
            .builder
            .build_int_mul(thread_id, size, "i_times_size")?;
        let start = self
            .builder
            .build_int_unsigned_div(i_times_size, thread_dim, "start")?;

        // the ending index for thread i is (i+1) * size / thread_dim
        let i_plus_one = self.builder.build_int_add(thread_id, one, "i_plus_one")?;
        let i_plus_one_times_size =
            self.builder
                .build_int_mul(i_plus_one, size, "i_plus_one_times_size")?;
        let end = self
            .builder
            .build_int_unsigned_div(i_plus_one_times_size, thread_dim, "end")?;

        let test_done = self.builder.get_insert_block().unwrap();
        let next_block = self
            .context
            .append_basic_block(self.fn_value_opt.unwrap(), "threading_block");
        self.builder.position_at_end(next_block);

        Ok((start, end, test_done, next_block))
    }

    fn jit_end_threading(
        &mut self,
        start: IntValue<'ctx>,
        end: IntValue<'ctx>,
        test_done: BasicBlock<'ctx>,
        next: BasicBlock<'ctx>,
    ) -> Result<()> {
        let exit = self
            .context
            .append_basic_block(self.fn_value_opt.unwrap(), "exit");
        self.builder.build_unconditional_branch(exit)?;
        self.builder.position_at_end(test_done);
        // done if start == end
        let done = self
            .builder
            .build_int_compare(IntPredicate::EQ, start, end, "done")?;
        self.builder.build_conditional_branch(done, exit, next)?;
        self.builder.position_at_end(exit);
        Ok(())
    }

    pub fn write_bitcode_to_path(&self, path: &std::path::Path) {
        self.module.write_bitcode_to_path(path);
    }

    fn insert_constants(&mut self, model: &DiscreteModel) {
        if let Some(constants) = self.globals.constants.as_ref() {
            self.insert_param("constants", constants.as_pointer_value());
            for tensor in model.constant_defns() {
                self.insert_tensor(tensor, true);
            }
        }
    }

    fn insert_data(&mut self, model: &DiscreteModel) {
        self.insert_constants(model);

        if let Some(input) = model.input() {
            self.insert_tensor(input, false);
        }
        for tensor in model.input_dep_defns() {
            self.insert_tensor(tensor, false);
        }
        for tensor in model.time_dep_defns() {
            self.insert_tensor(tensor, false);
        }
        for tensor in model.state_dep_defns() {
            self.insert_tensor(tensor, false);
        }
    }

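    // With LLVM's opaque-pointer model there is a single pointer type per
    // address space, so the element-type parameter is unused here and kept
    // only to document intent at the call site.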
    fn pointer_type(context: &'ctx Context, _ty: BasicTypeEnum<'ctx>) -> PointerType<'ctx> {
        context.ptr_type(AddressSpace::default())
    }

    fn fn_pointer_type(context: &'ctx Context, _ty: FunctionType<'ctx>) -> PointerType<'ctx> {
        context.ptr_type(AddressSpace::default())
    }

    fn insert_indices(&mut self) {
        if let Some(indices) = self.globals.indices.as_ref() {
            let i32_type = self.context.i32_type();
            let zero = i32_type.const_int(0, false);
            let ptr = unsafe {
                indices
                    .as_pointer_value()
                    .const_in_bounds_gep(i32_type, &[zero])
            };
            self.variables.insert("indices".to_owned(), ptr);
        }
    }

    fn insert_param(&mut self, name: &str, value: PointerValue<'ctx>) {
        self.variables.insert(name.to_owned(), value);
    }

    fn build_gep<T: BasicType<'ctx>>(
        &self,
        ty: T,
        ptr: PointerValue<'ctx>,
        ordered_indexes: &[IntValue<'ctx>],
        name: &str,
    ) -> Result<PointerValue<'ctx>> {
        unsafe {
            self.builder
                .build_gep(ty, ptr, ordered_indexes, name)
                .map_err(|e| e.into())
        }
    }

    fn build_load<T: BasicType<'ctx>>(
        &self,
        ty: T,
        ptr: PointerValue<'ctx>,
        name: &str,
    ) -> Result<BasicValueEnum<'ctx>> {
        self.builder.build_load(ty, ptr, name).map_err(|e| e.into())
    }

    fn get_ptr_to_index<T: BasicType<'ctx>>(
        builder: &Builder<'ctx>,
        ty: T,
        ptr: &PointerValue<'ctx>,
        index: IntValue<'ctx>,
        name: &str,
    ) -> PointerValue<'ctx> {
        unsafe {
            builder
                .build_in_bounds_gep(ty, *ptr, &[index], name)
                .unwrap()
        }
    }

    fn insert_state(&mut self, u: &Tensor) {
        let mut data_index = 0;
        for blk in u.elmts() {
            if let Some(name) = blk.name() {
                let ptr = self.variables.get("u").unwrap();
                let i = self
                    .context
                    .i32_type()
                    .const_int(data_index.try_into().unwrap(), false);
                let alloca = Self::get_ptr_to_index(
                    &self.create_entry_block_builder(),
                    self.real_type,
                    ptr,
                    i,
                    blk.name().unwrap(),
                );
                self.variables.insert(name.to_owned(), alloca);
            }
            data_index += blk.nnz();
        }
    }
    fn insert_dot_state(&mut self, dudt: &Tensor) {
        let mut data_index = 0;
        for blk in dudt.elmts() {
            if let Some(name) = blk.name() {
                let ptr = self.variables.get("dudt").unwrap();
                let i = self
                    .context
                    .i32_type()
                    .const_int(data_index.try_into().unwrap(), false);
                let alloca = Self::get_ptr_to_index(
                    &self.create_entry_block_builder(),
                    self.real_type,
                    ptr,
                    i,
                    blk.name().unwrap(),
                );
                self.variables.insert(name.to_owned(), alloca);
            }
            data_index += blk.nnz();
        }
    }
    fn insert_tensor(&mut self, tensor: &Tensor, is_constant: bool) {
        let var_name = if is_constant { "constants" } else { "data" };
        let ptr = *self.variables.get(var_name).unwrap();
        let mut data_index = self.layout.get_data_index(tensor.name()).unwrap();
        let i = self
            .context
            .i32_type()
            .const_int(data_index.try_into().unwrap(), false);
        let alloca = Self::get_ptr_to_index(
            &self.create_entry_block_builder(),
            self.real_type,
            &ptr,
            i,
            tensor.name(),
        );
        self.variables.insert(tensor.name().to_owned(), alloca);

        //insert any named blocks
        for blk in tensor.elmts() {
            if let Some(name) = blk.name() {
                let i = self
                    .context
                    .i32_type()
                    .const_int(data_index.try_into().unwrap(), false);
                let alloca = Self::get_ptr_to_index(
                    &self.create_entry_block_builder(),
                    self.real_type,
                    &ptr,
                    i,
                    name,
                );
                self.variables.insert(name.to_owned(), alloca);
            }
            // named blocks are only supported for rank <= 1, so we can just add the nnz to get the next data index
            data_index += blk.nnz();
        }
    }
    fn get_param(&self, name: &str) -> &PointerValue<'ctx> {
        self.variables.get(name).unwrap()
    }

    fn get_var(&self, tensor: &Tensor) -> &PointerValue<'ctx> {
        self.variables.get(tensor.name()).unwrap()
    }

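    // Resolves a function by name, lazily declaring LLVM intrinsics
    // (`min`/`max` map to llvm.minnum/llvm.maxnum and `abs` to llvm.fabs, with
    // the scalar-type suffix taken from `diffsl_real_type`) and building the
    // custom math functions below on first use.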
    fn get_function(&mut self, name: &str) -> Option<FunctionValue<'ctx>> {
        match self.functions.get(name) {
            Some(&func) => Some(func),
            None => {
                let function = match name {
                    // support some llvm intrinsics
                    "sin" | "cos" | "tan" | "exp" | "log" | "log10" | "sqrt" | "abs"
                    | "copysign" | "pow" | "min" | "max" => {
                        let arg_len = 1;
                        let intrinsic_name = match name {
                            "min" => "minnum",
                            "max" => "maxnum",
                            "abs" => "fabs",
                            _ => name,
                        };
                        let llvm_name =
                            format!("llvm.{}.{}", intrinsic_name, self.diffsl_real_type.as_str());
                        let intrinsic = Intrinsic::find(&llvm_name).unwrap();
                        let ret_type = self.real_type;

                        let args_types = std::iter::repeat_n(ret_type, arg_len)
                            .map(|f| f.into())
                            .collect::<Vec<BasicTypeEnum>>();
                        // if we get an intrinsic, we don't need to add to the list of functions and can return early
                        return intrinsic.get_declaration(&self.module, args_types.as_slice());
                    }
                    // some custom functions
1332                    "sigmoid" => {
1333                        let arg_len = 1;
1334                        let ret_type = self.real_type;
1335
1336                        let args_types = std::iter::repeat_n(ret_type, arg_len)
1337                            .map(|f| f.into())
1338                            .collect::<Vec<BasicMetadataTypeEnum>>();
1339
1340                        let fn_type = ret_type.fn_type(args_types.as_slice(), false);
1341                        let fn_val = self.module.add_function(name, fn_type, None);
1342
1343                        for arg in fn_val.get_param_iter() {
1344                            arg.into_float_value().set_name("x");
1345                        }
1346
1347                        let current_block = self.builder.get_insert_block().unwrap();
1348                        let basic_block = self.context.append_basic_block(fn_val, "entry");
1349                        self.builder.position_at_end(basic_block);
1350                        let x = fn_val.get_nth_param(0)?.into_float_value();
1351                        let one = self.real_type.const_float(1.0);
1352                        let negx = self.builder.build_float_neg(x, name).ok()?;
1353                        let exp = self.get_function("exp").unwrap();
1354                        let exp_negx = self
1355                            .builder
1356                            .build_call(exp, &[BasicMetadataValueEnum::FloatValue(negx)], name)
1357                            .ok()?;
1358                        let one_plus_exp_negx = self
1359                            .builder
1360                            .build_float_add(
1361                                exp_negx
1362                                    .try_as_basic_value()
1363                                    .unwrap_basic()
1364                                    .into_float_value(),
1365                                one,
1366                                name,
1367                            )
1368                            .ok()?;
1369                        let sigmoid = self
1370                            .builder
1371                            .build_float_div(one, one_plus_exp_negx, name)
1372                            .ok()?;
1373                        self.builder.build_return(Some(&sigmoid)).ok();
1374                        self.builder.position_at_end(current_block);
1375                        Some(fn_val)
1376                    }
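                    // inverse hyperbolics via their log forms:
                    //   arcsinh(x) = ln(x + sqrt(x*x + 1))
                    //   arccosh(x) = ln(x + sqrt(x*x - 1))
                    // the +/-1 offset is the only difference between the two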
1377                    "arcsinh" | "arccosh" => {
1378                        let arg_len = 1;
1379                        let ret_type = self.real_type;
1380
1381                        let args_types = std::iter::repeat_n(ret_type, arg_len)
1382                            .map(|f| f.into())
1383                            .collect::<Vec<BasicMetadataTypeEnum>>();
1384
1385                        let fn_type = ret_type.fn_type(args_types.as_slice(), false);
1386                        let fn_val = self.module.add_function(name, fn_type, None);
1387
1388                        for arg in fn_val.get_param_iter() {
1389                            arg.into_float_value().set_name("x");
1390                        }
1391
1392                        let current_block = self.builder.get_insert_block().unwrap();
1393                        let basic_block = self.context.append_basic_block(fn_val, "entry");
1394                        self.builder.position_at_end(basic_block);
1395                        let x = fn_val.get_nth_param(0)?.into_float_value();
1396                        let offset = match name {
1397                            "arccosh" => self.real_type.const_float(-1.0),
1398                            "arcsinh" => self.real_type.const_float(1.0),
1399                            _ => panic!("unknown function"),
1400                        };
1401                        let x_squared = self.builder.build_float_mul(x, x, name).ok()?;
1402                        let x_squared_plus_offset =
1403                            self.builder.build_float_add(x_squared, offset, name).ok()?;
1404                        let sqrt = self.get_function("sqrt").unwrap();
1405                        let sqrt_term = self
1406                            .builder
1407                            .build_call(
1408                                sqrt,
1409                                &[BasicMetadataValueEnum::FloatValue(x_squared_plus_offset)],
1410                                name,
1411                            )
1412                            .unwrap()
1413                            .try_as_basic_value()
1414                            .unwrap_basic()
1415                            .into_float_value();
1416                        let x_plus_sqrt_term = self
1417                            .builder
1418                            .build_float_add(x, sqrt_term, name)
1419                            .ok()?;
1420                        let ln = self.get_function("log").unwrap();
1421                        let result = self
1422                            .builder
1423                            .build_call(
1424                                ln,
1425                                &[BasicMetadataValueEnum::FloatValue(
1426                                    x_plus_sqrt_term,
1427                                )],
1428                                name,
1429                            )
1430                            .unwrap()
1431                            .try_as_basic_value()
1432                            .unwrap_basic()
1433                            .into_float_value();
1434                        self.builder.build_return(Some(&result)).ok();
1435                        self.builder.position_at_end(current_block);
1436                        Some(fn_val)
1437                    }
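                    // heaviside(x) = 1.0 if x >= 0 else 0.0, generated as a select over
                    // an ordered >= comparison (this convention gives heaviside(0) = 1)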
1438                    "heaviside" => {
1439                        let arg_len = 1;
1440                        let ret_type = self.real_type;
1441
1442                        let args_types = std::iter::repeat_n(ret_type, arg_len)
1443                            .map(|f| f.into())
1444                            .collect::<Vec<BasicMetadataTypeEnum>>();
1445
1446                        let fn_type = ret_type.fn_type(args_types.as_slice(), false);
1447                        let fn_val = self.module.add_function(name, fn_type, None);
1448
1449                        for arg in fn_val.get_param_iter() {
1450                            arg.into_float_value().set_name("x");
1451                        }
1452
1453                        let current_block = self.builder.get_insert_block().unwrap();
1454                        let basic_block = self.context.append_basic_block(fn_val, "entry");
1455                        self.builder.position_at_end(basic_block);
1456                        let x = fn_val.get_nth_param(0)?.into_float_value();
1457                        let zero = self.real_type.const_float(0.0);
1458                        let one = self.real_type.const_float(1.0);
1459                        let result = self
1460                            .builder
1461                            .build_select(
1462                                self.builder
1463                                    .build_float_compare(FloatPredicate::OGE, x, zero, "x >= 0")
1464                                    .unwrap(),
1465                                one,
1466                                zero,
1467                                name,
1468                            )
1469                            .ok()?;
1470                        self.builder.build_return(Some(&result)).ok();
1471                        self.builder.position_at_end(current_block);
1472                        Some(fn_val)
1473                    }
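                    // hyperbolics built from exponentials:
                    //   tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))
                    //   sinh(x) = (exp(x) - exp(-x)) / 2
                    //   cosh(x) = (exp(x) + exp(-x)) / 2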
1474                    "tanh" | "sinh" | "cosh" => {
1475                        let arg_len = 1;
1476                        let ret_type = self.real_type;
1477
1478                        let args_types = std::iter::repeat_n(ret_type, arg_len)
1479                            .map(|f| f.into())
1480                            .collect::<Vec<BasicMetadataTypeEnum>>();
1481
1482                        let fn_type = ret_type.fn_type(args_types.as_slice(), false);
1483                        let fn_val = self.module.add_function(name, fn_type, None);
1484
1485                        for arg in fn_val.get_param_iter() {
1486                            arg.into_float_value().set_name("x");
1487                        }
1488
1489                        let current_block = self.builder.get_insert_block().unwrap();
1490                        let basic_block = self.context.append_basic_block(fn_val, "entry");
1491                        self.builder.position_at_end(basic_block);
1492                        let x = fn_val.get_nth_param(0)?.into_float_value();
1493                        let negx = self.builder.build_float_neg(x, name).ok()?;
1494                        let exp = self.get_function("exp").unwrap();
1495                        let exp_negx = self
1496                            .builder
1497                            .build_call(exp, &[BasicMetadataValueEnum::FloatValue(negx)], name)
1498                            .ok()?;
1499                        let expx = self
1500                            .builder
1501                            .build_call(exp, &[BasicMetadataValueEnum::FloatValue(x)], name)
1502                            .ok()?;
1503                        let expx_minus_exp_negx = self
1504                            .builder
1505                            .build_float_sub(
1506                                expx.try_as_basic_value().unwrap_basic().into_float_value(),
1507                                exp_negx
1508                                    .try_as_basic_value()
1509                                    .unwrap_basic()
1510                                    .into_float_value(),
1511                                name,
1512                            )
1513                            .ok()?;
1514                        let expx_plus_exp_negx = self
1515                            .builder
1516                            .build_float_add(
1517                                expx.try_as_basic_value().unwrap_basic().into_float_value(),
1518                                exp_negx
1519                                    .try_as_basic_value()
1520                                    .unwrap_basic()
1521                                    .into_float_value(),
1522                                name,
1523                            )
1524                            .ok()?;
1525                        let result = match name {
1526                            "tanh" => self
1527                                .builder
1528                                .build_float_div(expx_minus_exp_negx, expx_plus_exp_negx, name)
1529                                .ok()?,
1530                            "sinh" => self
1531                                .builder
1532                                .build_float_div(
1533                                    expx_minus_exp_negx,
1534                                    self.real_type.const_float(2.0),
1535                                    name,
1536                                )
1537                                .ok()?,
1538                            "cosh" => self
1539                                .builder
1540                                .build_float_div(
1541                                    expx_plus_exp_negx,
1542                                    self.real_type.const_float(2.0),
1543                                    name,
1544                                )
1545                                .ok()?,
1546                            _ => panic!("unknown function"),
1547                        };
1548                        self.builder.build_return(Some(&result)).ok();
1549                        self.builder.position_at_end(current_block);
1550                        Some(fn_val)
1551                    }
1552                    _ => None,
1553                }?;
1554                self.functions.insert(name.to_owned(), function);
1555                Some(function)
1556            }
1557        }
1558    }
1559    /// Returns the `FunctionValue` representing the function being compiled.
1560    #[inline]
1561    fn fn_value(&self) -> FunctionValue<'ctx> {
1562        self.fn_value_opt.unwrap()
1563    }
1564
1565    #[inline]
1566    fn tensor_ptr(&self) -> PointerValue<'ctx> {
1567        self.tensor_ptr_opt.unwrap()
1568    }
1569
1570    /// Creates a builder positioned at the start of the function's entry block.
1571    fn create_entry_block_builder(&self) -> Builder<'ctx> {
1572        let builder = self.context.create_builder();
1573        let entry = self.fn_value().get_first_basic_block().unwrap();
1574        match entry.get_first_instruction() {
1575            Some(first_instr) => builder.position_before(&first_instr),
1576            None => builder.position_at_end(entry),
1577        }
1578        builder
1579    }
1580
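    /// Compiles a rank-0 (scalar) tensor into a single store through the result
    /// pointer. In threaded mode only thread 0 evaluates the expression; all
    /// other threads branch straight to the exit block.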
1581    fn jit_compile_scalar(
1582        &mut self,
1583        a: &Tensor,
1584        res_ptr_opt: Option<PointerValue<'ctx>>,
1585    ) -> Result<PointerValue<'ctx>> {
1586        let res_type = self.real_type;
1587        let res_ptr = match res_ptr_opt {
1588            Some(ptr) => ptr,
1589            None => self
1590                .create_entry_block_builder()
1591                .build_alloca(res_type, a.name())?,
1592        };
1593        let name = a.name();
1594        let elmt = a.elmts().first().unwrap();
1595
1596        // if threaded then only the first thread will evaluate the scalar
1597        let curr_block = self.builder.get_insert_block().unwrap();
1598        let mut next_block_opt = None;
1599        if self.threaded {
1600            let next_block = self.context.append_basic_block(self.fn_value(), "next");
1601            self.builder.position_at_end(next_block);
1602            next_block_opt = Some(next_block);
1603        }
1604
1605        let zero = self.int_type.const_zero();
1606        let float_value = self.jit_compile_expr(name, elmt.expr(), &[], elmt, zero)?;
1607        self.builder.build_store(res_ptr, float_value)?;
1608
1609        // complete the threading block
1610        if self.threaded {
1611            let exit_block = self.context.append_basic_block(self.fn_value(), "exit");
1612            self.builder.build_unconditional_branch(exit_block)?;
1613            self.builder.position_at_end(curr_block);
1614
1615            let thread_id = self.get_param("thread_id");
1616            let thread_id = self
1617                .builder
1618                .build_load(self.int_type, *thread_id, "thread_id")
1619                .unwrap()
1620                .into_int_value();
1621            let is_first_thread = self.builder.build_int_compare(
1622                IntPredicate::EQ,
1623                thread_id,
1624                self.int_type.const_zero(),
1625                "is_first_thread",
1626            )?;
1627            self.builder.build_conditional_branch(
1628                is_first_thread,
1629                next_block_opt.unwrap(),
1630                exit_block,
1631            )?;
1632
1633            self.builder.position_at_end(exit_block);
1634        }
1635
1636        Ok(res_ptr)
1637    }
1638
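    /// Compiles a tensor by compiling each of its blocks in turn; rank-0 tensors
    /// are delegated to `jit_compile_scalar`. Returns a pointer to the tensor's
    /// storage, allocated in the entry block if one was not supplied.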
1639    fn jit_compile_tensor(
1640        &mut self,
1641        a: &Tensor,
1642        res_ptr_opt: Option<PointerValue<'ctx>>,
1643    ) -> Result<PointerValue<'ctx>> {
1644        // treat scalar as a special case
1645        if a.rank() == 0 {
1646            return self.jit_compile_scalar(a, res_ptr_opt);
1647        }
1648
1649        let res_type = self.real_type;
1650        let res_ptr = match res_ptr_opt {
1651            Some(ptr) => ptr,
1652            None => self
1653                .create_entry_block_builder()
1654                .build_alloca(res_type, a.name())?,
1655        };
1656
1657        // set up the tensor storage pointer and index into this data
1658        self.tensor_ptr_opt = Some(res_ptr);
1659
1660        for (i, blk) in a.elmts().iter().enumerate() {
1661            let default = format!("{}-{}", a.name(), i);
1662            let name = blk.name().unwrap_or(default.as_str());
1663            self.jit_compile_block(name, a, blk)?;
1664        }
1665        Ok(res_ptr)
1666    }
1667
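    /// Dispatches a single tensor block to the dense, diagonal, or sparse
    /// codegen path based on the layout of its expression.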
1668    fn jit_compile_block(&mut self, name: &str, tensor: &Tensor, elmt: &TensorBlock) -> Result<()> {
1669        let translation = Translation::new(
1670            elmt.expr_layout(),
1671            elmt.layout(),
1672            elmt.start(),
1673            tensor.layout_ptr(),
1674        );
1675
1676        if elmt.expr_layout().is_dense() {
1677            self.jit_compile_dense_block(name, elmt, &translation)
1678        } else if elmt.expr_layout().is_diagonal() {
1679            self.jit_compile_diagonal_block(name, elmt, &translation)
1680        } else if elmt.expr_layout().is_sparse() {
1681            match translation.source {
1682                TranslationFrom::SparseContraction { .. } => {
1683                    self.jit_compile_sparse_contraction_block(name, elmt, &translation)
1684                }
1685                _ => self.jit_compile_sparse_block(name, elmt, &translation),
1686            }
1687        } else {
1688            Err(anyhow!(
1689                "unsupported block layout: {:?}",
1690                elmt.expr_layout()
1691            ))
1692        }
1693    }
1694
1695    // for dense blocks we generate a loop nest, one loop per dimension of the expression, and compile the expression at the element index given by the loop indices
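    // schematically, a rank-2 block emits something like
    //   for i0 in start..shape[0] { for i1 in 0..shape[1] { store(expr(i0, i1)) } }
    // built from basic blocks and phi nodes, with the outermost loop optionally
    // split across threads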
1696    fn jit_compile_dense_block(
1697        &mut self,
1698        name: &str,
1699        elmt: &TensorBlock,
1700        translation: &Translation,
1701    ) -> Result<()> {
1702        let int_type = self.int_type;
1703
1704        let mut preblock = self.builder.get_insert_block().unwrap();
1705        let expr_rank = elmt.expr_layout().rank();
1706        let expr_shape = elmt
1707            .expr_layout()
1708            .shape()
1709            .mapv(|n| int_type.const_int(n.try_into().unwrap(), false));
1710        let one = int_type.const_int(1, false);
1711
1712        let mut expr_strides = vec![1; expr_rank];
1713        if expr_rank > 0 {
1714            for i in (0..expr_rank - 1).rev() {
1715                expr_strides[i] = expr_strides[i + 1] * elmt.expr_layout().shape()[i + 1];
1716            }
1717        }
1718        let expr_strides = expr_strides
1719            .iter()
1720            .map(|&s| int_type.const_int(s.try_into().unwrap(), false))
1721            .collect::<Vec<IntValue>>();
1722
1723        // setup indices, loop through the nested loops
1724        let mut indices = Vec::new();
1725        let mut blocks = Vec::new();
1726
1727        // allocate the contract sum if needed
1728        let (contract_sum, contract_by, contract_strides) =
1729            if let TranslationFrom::DenseContraction {
1730                contract_by,
1731                contract_len: _,
1732            } = translation.source
1733            {
1734                let contract_rank = expr_rank - contract_by;
1735                let mut contract_strides = vec![1; contract_rank];
1736                for i in (0..contract_rank - 1).rev() {
1737                    contract_strides[i] =
1738                        contract_strides[i + 1] * elmt.expr_layout().shape()[i + 1];
1739                }
1740                let contract_strides = contract_strides
1741                    .iter()
1742                    .map(|&s| int_type.const_int(s.try_into().unwrap(), false))
1743                    .collect::<Vec<IntValue>>();
1744                (
1745                    Some(self.builder.build_alloca(self.real_type, "contract_sum")?),
1746                    contract_by,
1747                    Some(contract_strides),
1748                )
1749            } else {
1750                (None, 0, None)
1751            };
1752
1753        // we will thread the output loop, except if we are contracting to a scalar
1754        let (thread_start, thread_end, test_done, next) = if self.threaded {
1755            let (start, end, test_done, next) =
1756                self.jit_threading_limits(*expr_shape.get(0).unwrap_or(&one))?;
1757            preblock = next;
1758            (Some(start), Some(end), Some(test_done), Some(next))
1759        } else {
1760            (None, None, None, None)
1761        };
1762
1763        for i in 0..expr_rank {
1764            let block = self.context.append_basic_block(self.fn_value(), name);
1765            self.builder.build_unconditional_branch(block)?;
1766            self.builder.position_at_end(block);
1767
1768            let start_index = if i == 0 && self.threaded {
1769                thread_start.unwrap()
1770            } else {
1771                self.int_type.const_zero()
1772            };
1773
1774            let curr_index = self.builder.build_phi(int_type, format!("i{i}").as_str())?;
1775            curr_index.add_incoming(&[(&start_index, preblock)]);
1776
1777            if i == expr_rank - contract_by - 1 {
1778                if let Some(contract_sum) = contract_sum {
1779                    self.builder
1780                        .build_store(contract_sum, self.real_type.const_zero())?;
1781                }
1782            }
1783
1784            indices.push(curr_index);
1785            blocks.push(block);
1786            preblock = block;
1787        }
1788
1789        let indices_int: Vec<IntValue> = indices
1790            .iter()
1791            .map(|i| i.as_basic_value().into_int_value())
1792            .collect();
1793
1794        // if indices = (i, j, k) and shape = (a, b, c) calculate expr_index = (k + j*c + i*b*c)
1795        let mut expr_index = *indices_int.last().unwrap_or(&int_type.const_zero());
1796        let mut stride = 1u64;
1797        if !indices.is_empty() {
1798            for i in (0..indices.len() - 1).rev() {
1799                let iname_i = indices_int[i];
1800                let shapei: u64 = elmt.expr_layout().shape()[i + 1].try_into().unwrap();
1801                stride *= shapei;
1802                let stride_intval = self.context.i32_type().const_int(stride, false);
1803                let stride_mul_i = self.builder.build_int_mul(stride_intval, iname_i, name)?;
1804                expr_index = self.builder.build_int_add(expr_index, stride_mul_i, name)?;
1805            }
1806        }
1807
1808        let float_value =
1809            self.jit_compile_expr(name, elmt.expr(), indices_int.as_slice(), elmt, expr_index)?;
1810
1811        if let Some(contract_sum) = contract_sum {
1812            let contract_sum_value = self
1813                .build_load(self.real_type, contract_sum, "contract_sum")?
1814                .into_float_value();
1815            let new_contract_sum_value = self.builder.build_float_add(
1816                contract_sum_value,
1817                float_value,
1818                "new_contract_sum",
1819            )?;
1820            self.builder
1821                .build_store(contract_sum, new_contract_sum_value)?;
1822        } else {
1823            let expr_index = indices_int.iter().zip(expr_strides.iter()).fold(
1824                self.int_type.const_zero(),
1825                |acc, (i, s)| {
1826                    let tmp = self.builder.build_int_mul(*i, *s, "expr_index").unwrap();
1827                    self.builder.build_int_add(acc, tmp, "acc").unwrap()
1828                },
1829            );
1830            self.jit_compile_broadcast_and_store(name, elmt, float_value, expr_index, translation)?;
1831        }
1832
1833        let mut postblock = self.builder.get_insert_block().unwrap();
1834
1835        // unwind the nested loops
1836        for i in (0..expr_rank).rev() {
1837            // increment index
1838            let next_index = self.builder.build_int_add(indices_int[i], one, name)?;
1839            indices[i].add_incoming(&[(&next_index, postblock)]);
1840            if i == expr_rank - contract_by - 1 {
1841                if let Some(contract_sum) = contract_sum {
1842                    let contract_sum_value = self
1843                        .build_load(self.real_type, contract_sum, "contract_sum")?
1844                        .into_float_value();
1845                    let contract_strides = contract_strides.as_ref().unwrap();
1846                    let elmt_index = indices_int
1847                        .iter()
1848                        .take(contract_strides.len())
1849                        .zip(contract_strides.iter())
1850                        .fold(self.int_type.const_zero(), |acc, (i, s)| {
1851                            let tmp = self.builder.build_int_mul(*i, *s, "elmt_index").unwrap();
1852                            self.builder.build_int_add(acc, tmp, "acc").unwrap()
1853                        });
1854                    self.jit_compile_store(
1855                        name,
1856                        elmt,
1857                        elmt_index,
1858                        contract_sum_value,
1859                        translation,
1860                    )?;
1861                }
1862            }
1863
1864            let end_index = if i == 0 && self.threaded {
1865                thread_end.unwrap()
1866            } else {
1867                expr_shape[i]
1868            };
1869
1870            // loop condition
1871            let loop_while =
1872                self.builder
1873                    .build_int_compare(IntPredicate::ULT, next_index, end_index, name)?;
1874            let block = self.context.append_basic_block(self.fn_value(), name);
1875            self.builder
1876                .build_conditional_branch(loop_while, blocks[i], block)?;
1877            self.builder.position_at_end(block);
1878            postblock = block;
1879        }
1880
1881        if self.threaded {
1882            self.jit_end_threading(
1883                thread_start.unwrap(),
1884                thread_end.unwrap(),
1885                test_done.unwrap(),
1886                next.unwrap(),
1887            )?;
1888        }
1889        Ok(())
1890    }
1891
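    // for sparse contractions the data layout stores, for each output non-zero i,
    // the half-open range [indices[t + 2*i], indices[t + 2*i + 1]) of expression
    // elements to sum, where t is the translation index looked up below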
1892    fn jit_compile_sparse_contraction_block(
1893        &mut self,
1894        name: &str,
1895        elmt: &TensorBlock,
1896        translation: &Translation,
1897    ) -> Result<()> {
1898        match translation.source {
1899            TranslationFrom::SparseContraction { .. } => {}
1900            _ => {
1901                panic!("expected sparse contraction")
1902            }
1903        }
1904        let int_type = self.int_type;
1905
1906        let translation_index = self
1907            .layout
1908            .get_translation_index(elmt.expr_layout(), elmt.layout())
1909            .unwrap();
1910        let translation_index = translation_index + translation.get_from_index_in_data_layout();
1911
1912        let final_contract_index =
1913            int_type.const_int(elmt.layout().nnz().try_into().unwrap(), false);
1914        let (thread_start, thread_end, test_done, next) = if self.threaded {
1915            let (start, end, test_done, next) = self.jit_threading_limits(final_contract_index)?;
1916            (Some(start), Some(end), Some(test_done), Some(next))
1917        } else {
1918            (None, None, None, None)
1919        };
1920
1921        let preblock = self.builder.get_insert_block().unwrap();
1922        let contract_sum_ptr = self.builder.build_alloca(self.real_type, "contract_sum")?;
1923
1924        // loop through each contraction
1925        let block = self.context.append_basic_block(self.fn_value(), name);
1926        self.builder.build_unconditional_branch(block)?;
1927        self.builder.position_at_end(block);
1928
1929        let contract_index = self.builder.build_phi(int_type, "i")?;
1930        let contract_start = if self.threaded {
1931            thread_start.unwrap()
1932        } else {
1933            int_type.const_zero()
1934        };
1935        contract_index.add_incoming(&[(&contract_start, preblock)]);
1936
1937        let start_index = self.builder.build_int_add(
1938            int_type.const_int(translation_index.try_into().unwrap(), false),
1939            self.builder.build_int_mul(
1940                int_type.const_int(2, false),
1941                contract_index.as_basic_value().into_int_value(),
1942                name,
1943            )?,
1944            name,
1945        )?;
1946        let end_index =
1947            self.builder
1948                .build_int_add(start_index, int_type.const_int(1, false), name)?;
1949        let start_ptr = self.build_gep(
1950            self.int_type,
1951            *self.get_param("indices"),
1952            &[start_index],
1953            "start_index_ptr",
1954        )?;
1955        let start_contract = self
1956            .build_load(self.int_type, start_ptr, "start")?
1957            .into_int_value();
1958        let end_ptr = self.build_gep(
1959            self.int_type,
1960            *self.get_param("indices"),
1961            &[end_index],
1962            "end_index_ptr",
1963        )?;
1964        let end_contract = self
1965            .build_load(self.int_type, end_ptr, "end")?
1966            .into_int_value();
1967
1968        // initialise the contract sum
1969        self.builder
1970            .build_store(contract_sum_ptr, self.real_type.const_float(0.0))?;
1971
1972        // loop through each element in the contraction
1973        let start_contract_block = self
1974            .context
1975            .append_basic_block(self.fn_value(), format!("{name}_contract").as_str());
1976        self.builder
1977            .build_unconditional_branch(start_contract_block)?;
1978        self.builder.position_at_end(start_contract_block);
1979
1980        let expr_index_phi = self.builder.build_phi(int_type, "j")?;
1981        expr_index_phi.add_incoming(&[(&start_contract, block)]);
1982
1983        let expr_index = expr_index_phi.as_basic_value().into_int_value();
1984        let indices_int = self.expr_indices_from_elmt_index(expr_index, elmt, name)?;
1985
1986        // loop body - eval expression and increment sum
1987        let float_value =
1988            self.jit_compile_expr(name, elmt.expr(), indices_int.as_slice(), elmt, expr_index)?;
1989        let contract_sum_value = self
1990            .build_load(self.real_type, contract_sum_ptr, "contract_sum")?
1991            .into_float_value();
1992        let new_contract_sum_value =
1993            self.builder
1994                .build_float_add(contract_sum_value, float_value, "new_contract_sum")?;
1995        self.builder
1996            .build_store(contract_sum_ptr, new_contract_sum_value)?;
1997
1998        let end_contract_block = self.builder.get_insert_block().unwrap();
1999
2000        // increment contract loop index
2001        let next_elmt_index =
2002            self.builder
2003                .build_int_add(expr_index, int_type.const_int(1, false), name)?;
2004        expr_index_phi.add_incoming(&[(&next_elmt_index, end_contract_block)]);
2005
2006        // contract loop condition
2007        let loop_while = self.builder.build_int_compare(
2008            IntPredicate::ULT,
2009            next_elmt_index,
2010            end_contract,
2011            name,
2012        )?;
2013        let post_contract_block = self.context.append_basic_block(self.fn_value(), name);
2014        self.builder.build_conditional_branch(
2015            loop_while,
2016            start_contract_block,
2017            post_contract_block,
2018        )?;
2019        self.builder.position_at_end(post_contract_block);
2020
2021        // store the result
2022        self.jit_compile_store(
2023            name,
2024            elmt,
2025            contract_index.as_basic_value().into_int_value(),
2026            new_contract_sum_value,
2027            translation,
2028        )?;
2029
2030        // increment outer loop index
2031        let next_contract_index = self.builder.build_int_add(
2032            contract_index.as_basic_value().into_int_value(),
2033            int_type.const_int(1, false),
2034            name,
2035        )?;
2036        contract_index.add_incoming(&[(&next_contract_index, post_contract_block)]);
2037
2038        // outer loop condition
2039        let loop_while = self.builder.build_int_compare(
2040            IntPredicate::ULT,
2041            next_contract_index,
2042            thread_end.unwrap_or(final_contract_index),
2043            name,
2044        )?;
2045        let post_block = self.context.append_basic_block(self.fn_value(), name);
2046        self.builder
2047            .build_conditional_branch(loop_while, block, post_block)?;
2048        self.builder.position_at_end(post_block);
2049
2050        if self.threaded {
2051            self.jit_end_threading(
2052                thread_start.unwrap(),
2053                thread_end.unwrap(),
2054                test_done.unwrap(),
2055                next.unwrap(),
2056            )?;
2057        }
2058
2059        Ok(())
2060    }
2061
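    /// Recovers the rank-sized multi-index of a sparse element from the layout:
    /// component i of element `elmt_index` is loaded from
    /// `indices[layout_index + elmt_index * rank + i]`.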
2062    fn expr_indices_from_elmt_index(
2063        &mut self,
2064        elmt_index: IntValue<'ctx>,
2065        elmt: &TensorBlock,
2066        name: &str,
2067    ) -> Result<Vec<IntValue<'ctx>>, anyhow::Error> {
2068        let layout_index = self.layout.get_layout_index(elmt.expr_layout()).unwrap();
2069        let int_type = self.int_type;
2070        // loop body - load index from layout
2071        let elmt_index_mult_rank = self.builder.build_int_mul(
2072            elmt_index,
2073            int_type.const_int(elmt.expr_layout().rank().try_into().unwrap(), false),
2074            name,
2075        )?;
2076        (0..elmt.expr_layout().rank())
2077            .map(|i| {
2078                let layout_index_plus_offset =
2079                    int_type.const_int((layout_index + i).try_into().unwrap(), false);
2080                let curr_index = self.builder.build_int_add(
2081                    elmt_index_mult_rank,
2082                    layout_index_plus_offset,
2083                    name,
2084                )?;
2085                let ptr = Self::get_ptr_to_index(
2086                    &self.builder,
2087                    self.int_type,
2088                    self.get_param("indices"),
2089                    curr_index,
2090                    name,
2091                );
2092                Ok(self.build_load(self.int_type, ptr, name)?.into_int_value())
2093            })
2094            .collect::<Result<Vec<_>, anyhow::Error>>()
2095    }
2096
2097    // for sparse blocks we loop through the non-zero elements, extract each element's index from the layout, and compile the expression at that index
2098    // TODO: haven't implemented contractions yet
2099    fn jit_compile_sparse_block(
2100        &mut self,
2101        name: &str,
2102        elmt: &TensorBlock,
2103        translation: &Translation,
2104    ) -> Result<()> {
2105        let int_type = self.int_type;
2106
2107        let start_index = int_type.const_int(0, false);
2108        let end_index = int_type.const_int(elmt.expr_layout().nnz().try_into().unwrap(), false);
2109
2110        let (thread_start, thread_end, test_done, next) = if self.threaded {
2111            let (start, end, test_done, next) = self.jit_threading_limits(end_index)?;
2112            (Some(start), Some(end), Some(test_done), Some(next))
2113        } else {
2114            (None, None, None, None)
2115        };
2116
2117        // loop through the non-zero elements
2118        let preblock = self.builder.get_insert_block().unwrap();
2119        let loop_block = self.context.append_basic_block(self.fn_value(), name);
2120        self.builder.build_unconditional_branch(loop_block)?;
2121        self.builder.position_at_end(loop_block);
2122
2123        let curr_index = self.builder.build_phi(int_type, "i")?;
2124        curr_index.add_incoming(&[(&thread_start.unwrap_or(start_index), preblock)]);
2125
2126        let elmt_index = curr_index.as_basic_value().into_int_value();
2127        let indices_int = self.expr_indices_from_elmt_index(elmt_index, elmt, name)?;
2128
2129        // loop body - eval expression
2130        let float_value =
2131            self.jit_compile_expr(name, elmt.expr(), indices_int.as_slice(), elmt, elmt_index)?;
2132
2133        self.jit_compile_broadcast_and_store(name, elmt, float_value, elmt_index, translation)?;
2134
2135        // jit_compile_expr or jit_compile_broadcast_and_store may have changed the current block
2136        let end_loop_block = self.builder.get_insert_block().unwrap();
2137
2138        // increment loop index
2139        let one = int_type.const_int(1, false);
2140        let next_index = self.builder.build_int_add(elmt_index, one, name)?;
2141        curr_index.add_incoming(&[(&next_index, end_loop_block)]);
2142
2143        // loop condition
2144        let loop_while = self.builder.build_int_compare(
2145            IntPredicate::ULT,
2146            next_index,
2147            thread_end.unwrap_or(end_index),
2148            name,
2149        )?;
2150        let post_block = self.context.append_basic_block(self.fn_value(), name);
2151        self.builder
2152            .build_conditional_branch(loop_while, loop_block, post_block)?;
2153        self.builder.position_at_end(post_block);
2154
2155        if self.threaded {
2156            self.jit_end_threading(
2157                thread_start.unwrap(),
2158                thread_end.unwrap(),
2159                test_done.unwrap(),
2160                next.unwrap(),
2161            )?;
2162        }
2163
2164        Ok(())
2165    }
2166
2167    // for diagonal blocks we loop through the diagonal elements; the index is the same in every dimension, and we compile the expression at that index
2168    fn jit_compile_diagonal_block(
2169        &mut self,
2170        name: &str,
2171        elmt: &TensorBlock,
2172        translation: &Translation,
2173    ) -> Result<()> {
2174        let int_type = self.int_type;
2175
2176        let start_index = int_type.const_int(0, false);
2177        let end_index = int_type.const_int(elmt.expr_layout().nnz().try_into().unwrap(), false);
2178
2179        let (thread_start, thread_end, test_done, next) = if self.threaded {
2180            let (start, end, test_done, next) = self.jit_threading_limits(end_index)?;
2181            (Some(start), Some(end), Some(test_done), Some(next))
2182        } else {
2183            (None, None, None, None)
2184        };
2185
2186        // loop through the non-zero elements
2187        let preblock = self.builder.get_insert_block().unwrap();
2188        let start_loop_block = self.context.append_basic_block(self.fn_value(), name);
2189        self.builder.build_unconditional_branch(start_loop_block)?;
2190        self.builder.position_at_end(start_loop_block);
2191
2192        let curr_index = self.builder.build_phi(int_type, "i")?;
2193        curr_index.add_incoming(&[(&thread_start.unwrap_or(start_index), preblock)]);
2194
2195        // loop body - index is just the same for each element
2196        let elmt_index = curr_index.as_basic_value().into_int_value();
2197        let indices_int: Vec<IntValue> =
2198            (0..elmt.expr_layout().rank()).map(|_| elmt_index).collect();
2199
2200        // loop body - eval expression
2201        let float_value =
2202            self.jit_compile_expr(name, elmt.expr(), indices_int.as_slice(), elmt, elmt_index)?;
2203
2204        // loop body - store result
2205        self.jit_compile_broadcast_and_store(name, elmt, float_value, elmt_index, translation)?;
2206
2207        let end_loop_block = self.builder.get_insert_block().unwrap();
2208
2209        // increment loop index
2210        let one = int_type.const_int(1, false);
2211        let next_index = self.builder.build_int_add(elmt_index, one, name)?;
2212        curr_index.add_incoming(&[(&next_index, end_loop_block)]);
2213
2214        // loop condition
2215        let loop_while = self.builder.build_int_compare(
2216            IntPredicate::ULT,
2217            next_index,
2218            thread_end.unwrap_or(end_index),
2219            name,
2220        )?;
2221        let post_block = self.context.append_basic_block(self.fn_value(), name);
2222        self.builder
2223            .build_conditional_branch(loop_while, start_loop_block, post_block)?;
2224        self.builder.position_at_end(post_block);
2225
2226        if self.threaded {
2227            self.jit_end_threading(
2228                thread_start.unwrap(),
2229                thread_end.unwrap(),
2230                test_done.unwrap(),
2231                next.unwrap(),
2232            )?;
2233        }
2234
2235        Ok(())
2236    }
2237
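    // when the expression layout is broadcast into the block layout, each computed
    // value is stored broadcast_len times, at store_index = expr_index * broadcast_len + j
    // for j in 0..broadcast_len; element-wise and diagonal-contraction translations
    // store the value once at expr_index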
2238    fn jit_compile_broadcast_and_store(
2239        &mut self,
2240        name: &str,
2241        elmt: &TensorBlock,
2242        float_value: FloatValue<'ctx>,
2243        expr_index: IntValue<'ctx>,
2244        translation: &Translation,
2245    ) -> Result<()> {
2246        let int_type = self.int_type;
2247        let one = int_type.const_int(1, false);
2248        let zero = int_type.const_int(0, false);
2249        let pre_block = self.builder.get_insert_block().unwrap();
2250        match translation.source {
2251            TranslationFrom::Broadcast {
2252                broadcast_by: _,
2253                broadcast_len,
2254            } => {
2255                let bcast_start_index = zero;
2256                let bcast_end_index = int_type.const_int(broadcast_len.try_into().unwrap(), false);
2257
2258                // setup loop block
2259                let bcast_block = self.context.append_basic_block(self.fn_value(), name);
2260                self.builder.build_unconditional_branch(bcast_block)?;
2261                self.builder.position_at_end(bcast_block);
2262                let bcast_index = self.builder.build_phi(int_type, "broadcast_index")?;
2263                bcast_index.add_incoming(&[(&bcast_start_index, pre_block)]);
2264
2265                // store value
2266                let store_index = self.builder.build_int_add(
2267                    self.builder
2268                        .build_int_mul(expr_index, bcast_end_index, "store_index")?,
2269                    bcast_index.as_basic_value().into_int_value(),
2270                    "bcast_store_index",
2271                )?;
2272                self.jit_compile_store(name, elmt, store_index, float_value, translation)?;
2273
2274                // increment index
2275                let bcast_next_index = self.builder.build_int_add(
2276                    bcast_index.as_basic_value().into_int_value(),
2277                    one,
2278                    name,
2279                )?;
2280                bcast_index.add_incoming(&[(&bcast_next_index, bcast_block)]);
2281
2282                // loop condition
2283                let bcast_cond = self.builder.build_int_compare(
2284                    IntPredicate::ULT,
2285                    bcast_next_index,
2286                    bcast_end_index,
2287                    "broadcast_cond",
2288                )?;
2289                let post_bcast_block = self.context.append_basic_block(self.fn_value(), name);
2290                self.builder
2291                    .build_conditional_branch(bcast_cond, bcast_block, post_bcast_block)?;
2292                self.builder.position_at_end(post_bcast_block);
2293                Ok(())
2294            }
2295            TranslationFrom::ElementWise | TranslationFrom::DiagonalContraction { .. } => {
2296                self.jit_compile_store(name, elmt, expr_index, float_value, translation)?;
2297                Ok(())
2298            }
2299            _ => Err(anyhow!("Invalid translation")),
2300        }
2301    }
2302
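    // translates a store index into the tensor's data layout: contiguous targets
    // add a constant start offset, while sparse targets load the final index from
    // the precomputed translation entries in the `indices` array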
2303    fn jit_compile_store(
2304        &mut self,
2305        name: &str,
2306        elmt: &TensorBlock,
2307        store_index: IntValue<'ctx>,
2308        float_value: FloatValue<'ctx>,
2309        translation: &Translation,
2310    ) -> Result<()> {
2311        let int_type = self.int_type;
2312        let res_index = match &translation.target {
2313            TranslationTo::Contiguous { start, end: _ } => {
2314                let start_const = int_type.const_int((*start).try_into().unwrap(), false);
2315                self.builder.build_int_add(start_const, store_index, name)?
2316            }
2317            TranslationTo::Sparse { indices: _ } => {
2318                // load store index from layout
2319                let translate_index = self
2320                    .layout
2321                    .get_translation_index(elmt.expr_layout(), elmt.layout())
2322                    .unwrap();
2323                let translate_store_index =
2324                    translate_index + translation.get_to_index_in_data_layout();
2325                let translate_store_index =
2326                    int_type.const_int(translate_store_index.try_into().unwrap(), false);
2327                let elmt_index_strided = store_index;
2328                let curr_index =
2329                    self.builder
2330                        .build_int_add(elmt_index_strided, translate_store_index, name)?;
2331                let ptr = Self::get_ptr_to_index(
2332                    &self.builder,
2333                    self.int_type,
2334                    self.get_param("indices"),
2335                    curr_index,
2336                    name,
2337                );
2338                self.build_load(self.int_type, ptr, name)?.into_int_value()
2339            }
2340        };
2341
2342        let resi_ptr = Self::get_ptr_to_index(
2343            &self.builder,
2344            self.real_type,
2345            &self.tensor_ptr(),
2346            res_index,
2347            name,
2348        );
2349        self.builder.build_store(resi_ptr, float_value)?;
2350        Ok(())
2351    }
2352
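    /// Recursively compiles an expression AST to a float value: binary and unary
    /// ops map directly onto LLVM float instructions, calls go through
    /// `get_function`, and names resolve to loads from the named tensor's storage
    /// at an element index computed from that tensor's layout.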
2353    fn jit_compile_expr(
2354        &mut self,
2355        name: &str,
2356        expr: &Ast,
2357        index: &[IntValue<'ctx>],
2358        elmt: &TensorBlock,
2359        expr_index: IntValue<'ctx>,
2360    ) -> Result<FloatValue<'ctx>> {
2361        let name = elmt.name().unwrap_or(name);
2362        match &expr.kind {
2363            AstKind::Binop(binop) => {
2364                let lhs =
2365                    self.jit_compile_expr(name, binop.left.as_ref(), index, elmt, expr_index)?;
2366                let rhs =
2367                    self.jit_compile_expr(name, binop.right.as_ref(), index, elmt, expr_index)?;
2368                match binop.op {
2369                    '*' => Ok(self.builder.build_float_mul(lhs, rhs, name)?),
2370                    '/' => Ok(self.builder.build_float_div(lhs, rhs, name)?),
2371                    '-' => Ok(self.builder.build_float_sub(lhs, rhs, name)?),
2372                    '+' => Ok(self.builder.build_float_add(lhs, rhs, name)?),
2373                    unknown => Err(anyhow!("unknown binop op '{}'", unknown)),
2374                }
2375            }
2376            AstKind::Monop(monop) => {
2377                let child =
2378                    self.jit_compile_expr(name, monop.child.as_ref(), index, elmt, expr_index)?;
2379                match monop.op {
2380                    '-' => Ok(self.builder.build_float_neg(child, name)?),
2381                    unknown => Err(anyhow!("unknown monop op '{}'", unknown)),
2382                }
2383            }
2384            AstKind::Call(call) => match self.get_function(call.fn_name) {
2385                Some(function) => {
2386                    let mut args: Vec<BasicMetadataValueEnum> = Vec::new();
2387                    for arg in call.args.iter() {
2388                        let arg_val =
2389                            self.jit_compile_expr(name, arg.as_ref(), index, elmt, expr_index)?;
2390                        args.push(BasicMetadataValueEnum::FloatValue(arg_val));
2391                    }
2392                    let ret_value = self
2393                        .builder
2394                        .build_call(function, args.as_slice(), name)?
2395                        .try_as_basic_value()
2396                        .unwrap_basic()
2397                        .into_float_value();
2398                    Ok(ret_value)
2399                }
2400                None => Err(anyhow!("unknown function call '{}'", call.fn_name)),
2401            },
2402            AstKind::CallArg(arg) => {
2403                self.jit_compile_expr(name, &arg.expression, index, elmt, expr_index)
2404            }
2405            AstKind::Number(value) => Ok(self.real_type.const_float(*value)),
2406            AstKind::Name(iname) => {
2407                let ptr = self.get_param(iname.name);
2408                let layout = self.layout.get_layout(iname.name).unwrap();
2409                let iname_elmt_index = if layout.is_dense() {
2410                    // permute indices based on the index chars of this tensor
2411                    let mut no_transform = true;
2412                    let mut iname_index = Vec::new();
2413                    for (i, c) in iname.indices.iter().enumerate() {
2414                        // find the position index of this index char in the tensor's index chars,
2415                        // if it's not found then it must be a contraction index so is at the end
2416                        let pi = elmt
2417                            .indices()
2418                            .iter()
2419                            .position(|x| x == c)
2420                            .unwrap_or(elmt.indices().len());
2421                        // if we are indexing, add the slice's start index to index[pi]
2422                        if let Some(indice) =
2423                            iname.indice.as_ref().map(|i| i.kind.as_indice().unwrap())
2424                        {
2425                            let start = indice.first.as_ref().kind.as_integer().unwrap();
2426                            let start_intval = self
2427                                .context
2428                                .i32_type()
2429                                .const_int(start.try_into().unwrap(), false);
2430                            // if we are indexing a single element, the index may be out of bounds
2431                            let index_pi = if pi >= index.len() {
2432                                self.context.i32_type().const_int(0, false)
2433                            } else {
2434                                index[pi]
2435                            };
2436                            let index_pi =
2437                                self.builder.build_int_add(index_pi, start_intval, name)?;
2438                            iname_index.push(index_pi);
2439                        } else {
2440                            iname_index.push(index[pi]);
2441                        }
2442                        no_transform = no_transform && pi == i;
2443                    }
2444                    // calculate the element index using iname_index and the shape of the tensor
2445                    // TODO: can we optimise this by using expr_index, and also including elmt_index?
2446                    if !iname_index.is_empty() {
2447                        let mut iname_elmt_index = *iname_index.last().unwrap();
2448                        let mut stride = 1u64;
2449                        for i in (0..iname_index.len() - 1).rev() {
2450                            let iname_i = iname_index[i];
2451                            let shapei: u64 = layout.shape()[i + 1].try_into().unwrap();
2452                            stride *= shapei;
2453                            let stride_intval = self.context.i32_type().const_int(stride, false);
2454                            let stride_mul_i =
2455                                self.builder.build_int_mul(stride_intval, iname_i, name)?;
2456                            iname_elmt_index =
2457                                self.builder
2458                                    .build_int_add(iname_elmt_index, stride_mul_i, name)?;
2459                        }
2460                        iname_elmt_index
2461                    } else {
2462                        // the name has no indices, so the element index
2463                        // is simply zero
2464                        self.context.i32_type().const_int(0, false)
2465                    }
2466                } else if layout.is_sparse() || layout.is_diagonal() {
2467                    let expr_layout = elmt.expr_layout();
2468
2469                    if expr_layout != layout {
2470                        // the layouts differ, so map the expression index through
2471                        // the binary layout map: mapped = indices[binary_layout_index + expr_index]
2472                        // a mapped index of -1 marks a structural zero, so the expression
2473                        // evaluates to 0.0; otherwise load the value at the mapped index
2474                        // this builds an if/else over the two cases, so the result is a
2475                        // phi value and we return early from this branch
2476                        let permutation =
2477                            DataLayout::permutation(elmt, iname.indices.as_slice(), layout);
2478                        if let Some(base_binary_layout_index) =
2479                            self.layout
2480                                .get_binary_layout_index(layout, expr_layout, permutation)
2481                        {
2482                            let binary_layout_index = self.builder.build_int_add(
2483                                self.int_type
2484                                    .const_int(base_binary_layout_index.try_into().unwrap(), false),
2485                                expr_index,
2486                                name,
2487                            )?;
2488
2489                            let indices_ptr = Self::get_ptr_to_index(
2490                                &self.builder,
2491                                self.int_type,
2492                                self.get_param("indices"),
2493                                binary_layout_index,
2494                                name,
2495                            );
2496                            let mapped_index = self
2497                                .build_load(self.int_type, indices_ptr, name)?
2498                                .into_int_value();
2499                            let is_less_than_zero = self.builder.build_int_compare(
2500                                IntPredicate::SLT,
2501                                mapped_index,
2502                                self.int_type.const_int(0, true),
2503                                "sparse_index_check",
2504                            )?;
2505                            let is_less_than_zero_block =
2506                                self.context.append_basic_block(self.fn_value(), "lt_zero");
2507                            let not_less_than_zero_block = self
2508                                .context
2509                                .append_basic_block(self.fn_value(), "not_lt_zero");
2510                            let merge_block =
2511                                self.context.append_basic_block(self.fn_value(), "merge");
2512                            self.builder.build_conditional_branch(
2513                                is_less_than_zero,
2514                                is_less_than_zero_block,
2515                                not_less_than_zero_block,
2516                            )?;
2517
2518                            // if mapped index < 0 return 0
2519                            self.builder.position_at_end(is_less_than_zero_block);
2520                            let zero_value = self.real_type.const_float(0.);
2521                            self.builder.build_unconditional_branch(merge_block)?;
2522
2523                            // if mapped index >=0 load value at that index
2524                            self.builder.position_at_end(not_less_than_zero_block);
2525                            let value_ptr = Self::get_ptr_to_index(
2526                                &self.builder,
2527                                self.real_type,
2528                                ptr,
2529                                mapped_index,
2530                                name,
2531                            );
2532                            let value = self
2533                                .build_load(self.real_type, value_ptr, name)?
2534                                .into_float_value();
2535                            self.builder.build_unconditional_branch(merge_block)?;
2536
2537                            // return value or 0 from if statement
2538                            self.builder.position_at_end(merge_block);
2539                            let if_return_value =
2540                                self.builder.build_phi(self.real_type, "sparse_value")?;
2541                            if_return_value.add_incoming(&[(&zero_value, is_less_than_zero_block)]);
2542                            if_return_value.add_incoming(&[(&value, not_less_than_zero_block)]);
2543
2544                            let phi_value = if_return_value.as_basic_value().into_float_value();
2545                            return Ok(phi_value);
2546                        } else {
2547                            expr_index
2548                        }
2549                    } else {
2550                        // we can just use the elmt_index since the layouts are the same
2551                        expr_index
2552                    }
2553                } else {
2554                    panic!("unexpected layout");
2555                };
2556                let value_ptr = Self::get_ptr_to_index(
2557                    &self.builder,
2558                    self.real_type,
2559                    ptr,
2560                    iname_elmt_index,
2561                    name,
2562                );
2563                Ok(self
2564                    .build_load(self.real_type, value_ptr, name)?
2565                    .into_float_value())
2566            }
2567            AstKind::NamedGradient(name) => {
2568                let name_str = name.to_string();
2569                let ptr = self.get_param(name_str.as_str());
2570                Ok(self
2571                    .build_load(self.real_type, *ptr, name_str.as_str())?
2572                    .into_float_value())
2573            }
2574            AstKind::Index(_) => todo!(),
2575            AstKind::Slice(_) => todo!(),
2576            AstKind::Integer(_) => todo!(),
2577            _ => panic!("unexpected AstKind"),
2578        }
2579    }
2580
2581    fn clear(&mut self) {
2582        self.variables.clear();
2583        //self.functions.clear();
2584        self.fn_value_opt = None;
2585        self.tensor_ptr_opt = None;
2586    }
2587
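    // pointer arguments are used directly; scalar int/float arguments are spilled
    // to an entry-block alloca so that every parameter can be accessed uniformly
    // through a pointer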
    fn function_arg_alloca(&mut self, name: &str, arg: BasicValueEnum<'ctx>) -> PointerValue<'ctx> {
        match arg {
            BasicValueEnum::PointerValue(v) => v,
            BasicValueEnum::FloatValue(v) => {
                let alloca = self
                    .create_entry_block_builder()
                    .build_alloca(arg.get_type(), name)
                    .unwrap();
                self.builder.build_store(alloca, v).unwrap();
                alloca
            }
            BasicValueEnum::IntValue(v) => {
                let alloca = self
                    .create_entry_block_builder()
                    .build_alloca(arg.get_type(), name)
                    .unwrap();
                self.builder.build_store(alloca, v).unwrap();
                alloca
            }
            _ => unreachable!(),
        }
    }

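    /// Compile the `set_u0` entry point, which evaluates the input-dependent
    /// definitions and writes the initial state into `u0`.
    ///
    /// A sketch of the generated function's signature, assuming `f64` reals
    /// and `u32` ints (the actual types follow `self.real_type` and
    /// `self.int_type`):
    ///
    /// ```ignore
    /// type SetU0Fn =
    ///     unsafe extern "C" fn(u0: *mut f64, data: *mut f64, thread_id: u32, thread_dim: u32);
    /// ```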
    pub fn compile_set_u0<'m>(&mut self, model: &'m DiscreteModel) -> Result<FunctionValue<'ctx>> {
        self.clear();
        let void_type = self.context.void_type();
        let fn_type = void_type.fn_type(
            &[
                self.real_ptr_type.into(),
                self.real_ptr_type.into(),
                self.int_type.into(),
                self.int_type.into(),
            ],
            false,
        );
        let fn_arg_names = &["u0", "data", "thread_id", "thread_dim"];
        let function = self.module.add_function("set_u0", fn_type, None);

        // mark the pointer arguments as noalias
        let alias_id = Attribute::get_named_enum_kind_id("noalias");
        let noalias_attr = self.context.create_enum_attribute(alias_id, 0);
        for i in &[0, 1] {
            function.add_attribute(AttributeLoc::Param(*i), noalias_attr);
        }

        let basic_block = self.context.append_basic_block(function, "entry");
        self.fn_value_opt = Some(function);
        self.builder.position_at_end(basic_block);

        for (i, arg) in function.get_param_iter().enumerate() {
            let name = fn_arg_names[i];
            let alloca = self.function_arg_alloca(name, arg);
            self.insert_param(name, alloca);
        }

        self.insert_data(model);
        self.insert_indices();

        let mut nbarriers = 0;
        let total_barriers = (model.input_dep_defns().len() + 1) as u64;
        #[allow(clippy::explicit_counter_loop)]
        for a in model.input_dep_defns() {
            self.jit_compile_tensor(a, Some(*self.get_var(a)))?;
            self.jit_compile_call_barrier(nbarriers, total_barriers);
            nbarriers += 1;
        }

        self.jit_compile_tensor(model.state(), Some(*self.get_param("u0")))?;
        self.jit_compile_call_barrier(nbarriers, total_barriers);

        self.builder.build_return(None)?;

        if function.verify(true) {
            Ok(function)
        } else {
            function.print_to_stderr();
            unsafe {
                function.delete();
            }
            Err(anyhow!("Invalid generated function."))
        }
    }

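    /// Compile the `calc_out` (or `calc_out_full`) entry point, which
    /// evaluates the definitions needed by the output tensor and writes the
    /// result into `out`. When `include_constants` is true the
    /// input-dependent (constant) definitions are recomputed as well and the
    /// function is named `calc_out_full`.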
    pub fn compile_calc_out<'m>(
        &mut self,
        model: &'m DiscreteModel,
        include_constants: bool,
    ) -> Result<FunctionValue<'ctx>> {
        self.clear();
        let void_type = self.context.void_type();
        let fn_type = void_type.fn_type(
            &[
                self.real_type.into(),
                self.real_ptr_type.into(),
                self.real_ptr_type.into(),
                self.real_ptr_type.into(),
                self.int_type.into(),
                self.int_type.into(),
            ],
            false,
        );
        let fn_arg_names = &["t", "u", "data", "out", "thread_id", "thread_dim"];
        let function_name = if include_constants {
            "calc_out_full"
        } else {
            "calc_out"
        };
        let function = self.module.add_function(function_name, fn_type, None);

        // mark the pointer arguments as noalias
        let alias_id = Attribute::get_named_enum_kind_id("noalias");
        let noalias_attr = self.context.create_enum_attribute(alias_id, 0);
        for i in &[1, 2] {
            function.add_attribute(AttributeLoc::Param(*i), noalias_attr);
        }

        let basic_block = self.context.append_basic_block(function, "entry");
        self.fn_value_opt = Some(function);
        self.builder.position_at_end(basic_block);

        for (i, arg) in function.get_param_iter().enumerate() {
            let name = fn_arg_names[i];
            let alloca = self.function_arg_alloca(name, arg);
            self.insert_param(name, alloca);
        }

        self.insert_state(model.state());
        self.insert_data(model);
        self.insert_indices();

        if let Some(out) = model.out() {
            let mut nbarriers = 0;
            let mut total_barriers =
                (model.time_dep_defns().len() + model.state_dep_defns().len() + 1) as u64;
            if include_constants {
                total_barriers += model.input_dep_defns().len() as u64;
                // calculate time-independent (constant) definitions
                for tensor in model.input_dep_defns() {
                    self.jit_compile_tensor(tensor, Some(*self.get_var(tensor)))?;
                    self.jit_compile_call_barrier(nbarriers, total_barriers);
                    nbarriers += 1;
                }
            }

            // calculate time-dependent definitions
            for tensor in model.time_dep_defns() {
                self.jit_compile_tensor(tensor, Some(*self.get_var(tensor)))?;
                self.jit_compile_call_barrier(nbarriers, total_barriers);
                nbarriers += 1;
            }

            // calculate state-dependent definitions
            #[allow(clippy::explicit_counter_loop)]
            for a in model.state_dep_defns() {
                self.jit_compile_tensor(a, Some(*self.get_var(a)))?;
                self.jit_compile_call_barrier(nbarriers, total_barriers);
                nbarriers += 1;
            }

            self.jit_compile_tensor(out, Some(*self.get_var(out)))?;
            self.jit_compile_call_barrier(nbarriers, total_barriers);
        }
        self.builder.build_return(None)?;

        if function.verify(true) {
            Ok(function)
        } else {
            function.print_to_stderr();
            unsafe {
                function.delete();
            }
            Err(anyhow!("Invalid generated function."))
        }
    }

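    /// Compile the `calc_stop` entry point, which evaluates the stop
    /// (root-finding) tensor into `root`. Generates an empty function if the
    /// model has no stop tensor.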
    pub fn compile_calc_stop<'m>(
        &mut self,
        model: &'m DiscreteModel,
    ) -> Result<FunctionValue<'ctx>> {
        self.clear();
        let void_type = self.context.void_type();
        let fn_type = void_type.fn_type(
            &[
                self.real_type.into(),
                self.real_ptr_type.into(),
                self.real_ptr_type.into(),
                self.real_ptr_type.into(),
                self.int_type.into(),
                self.int_type.into(),
            ],
            false,
        );
        let fn_arg_names = &["t", "u", "data", "root", "thread_id", "thread_dim"];
        let function = self.module.add_function("calc_stop", fn_type, None);

        // mark the pointer arguments as noalias
        let alias_id = Attribute::get_named_enum_kind_id("noalias");
        let noalias_attr = self.context.create_enum_attribute(alias_id, 0);
        for i in &[1, 2, 3] {
            function.add_attribute(AttributeLoc::Param(*i), noalias_attr);
        }

        let basic_block = self.context.append_basic_block(function, "entry");
        self.fn_value_opt = Some(function);
        self.builder.position_at_end(basic_block);

        for (i, arg) in function.get_param_iter().enumerate() {
            let name = fn_arg_names[i];
            let alloca = self.function_arg_alloca(name, arg);
            self.insert_param(name, alloca);
        }

        self.insert_state(model.state());
        self.insert_data(model);
        self.insert_indices();

        if let Some(stop) = model.stop() {
            // calculate time-dependent definitions
            let mut nbarriers = 0;
            let total_barriers =
                (model.time_dep_defns().len() + model.state_dep_defns().len() + 1) as u64;
            for tensor in model.time_dep_defns() {
                self.jit_compile_tensor(tensor, Some(*self.get_var(tensor)))?;
                self.jit_compile_call_barrier(nbarriers, total_barriers);
                nbarriers += 1;
            }

            // calculate state-dependent definitions
            for a in model.state_dep_defns() {
                self.jit_compile_tensor(a, Some(*self.get_var(a)))?;
                self.jit_compile_call_barrier(nbarriers, total_barriers);
                nbarriers += 1;
            }

            let res_ptr = self.get_param("root");
            self.jit_compile_tensor(stop, Some(*res_ptr))?;
            self.jit_compile_call_barrier(nbarriers, total_barriers);
        }
        self.builder.build_return(None)?;

        if function.verify(true) {
            Ok(function)
        } else {
            function.print_to_stderr();
            unsafe {
                function.delete();
            }
            Err(anyhow!("Invalid generated function."))
        }
    }

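    /// Compile the `rhs` (or `rhs_full`) entry point, which evaluates F, the
    /// right-hand side of the model, into `rr`. When `include_constants` is
    /// true the input-dependent definitions are recomputed first and the
    /// function is named `rhs_full`.
    ///
    /// A sketch of the generated signature, assuming `f64` reals and `u32`
    /// ints (the actual types follow `self.real_type` and `self.int_type`):
    ///
    /// ```ignore
    /// type RhsFn = unsafe extern "C" fn(
    ///     t: f64,
    ///     u: *mut f64,
    ///     data: *mut f64,
    ///     rr: *mut f64,
    ///     thread_id: u32,
    ///     thread_dim: u32,
    /// );
    /// ```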
    pub fn compile_rhs<'m>(
        &mut self,
        model: &'m DiscreteModel,
        include_constants: bool,
    ) -> Result<FunctionValue<'ctx>> {
        self.clear();
        let void_type = self.context.void_type();
        let fn_type = void_type.fn_type(
            &[
                self.real_type.into(),
                self.real_ptr_type.into(),
                self.real_ptr_type.into(),
                self.real_ptr_type.into(),
                self.int_type.into(),
                self.int_type.into(),
            ],
            false,
        );
        let fn_arg_names = &["t", "u", "data", "rr", "thread_id", "thread_dim"];
        let function_name = if include_constants { "rhs_full" } else { "rhs" };
        let function = self.module.add_function(function_name, fn_type, None);

        // mark the pointer arguments as noalias
        let alias_id = Attribute::get_named_enum_kind_id("noalias");
        let noalias_attr = self.context.create_enum_attribute(alias_id, 0);
        for i in &[1, 2, 3] {
            function.add_attribute(AttributeLoc::Param(*i), noalias_attr);
        }

        let basic_block = self.context.append_basic_block(function, "entry");
        self.fn_value_opt = Some(function);
        self.builder.position_at_end(basic_block);

        for (i, arg) in function.get_param_iter().enumerate() {
            let name = fn_arg_names[i];
            let alloca = self.function_arg_alloca(name, arg);
            self.insert_param(name, alloca);
        }

        self.insert_state(model.state());
        self.insert_data(model);
        self.insert_indices();

        let mut nbarriers = 0;
        let mut total_barriers =
            (model.time_dep_defns().len() + model.state_dep_defns().len() + 1) as u64;
        if include_constants {
            total_barriers += model.input_dep_defns().len() as u64;
            // calculate constant definitions
            for tensor in model.input_dep_defns() {
                self.jit_compile_tensor(tensor, Some(*self.get_var(tensor)))?;
                self.jit_compile_call_barrier(nbarriers, total_barriers);
                nbarriers += 1;
            }
        }

        // calculate time-dependent definitions
        for tensor in model.time_dep_defns() {
            self.jit_compile_tensor(tensor, Some(*self.get_var(tensor)))?;
            self.jit_compile_call_barrier(nbarriers, total_barriers);
            nbarriers += 1;
        }

        // calculate state-dependent definitions
        // TODO: could split state dep defns into before and after F
        for a in model.state_dep_defns() {
            self.jit_compile_tensor(a, Some(*self.get_var(a)))?;
            self.jit_compile_call_barrier(nbarriers, total_barriers);
            nbarriers += 1;
        }

        // F
        let res_ptr = self.get_param("rr");
        self.jit_compile_tensor(model.rhs(), Some(*res_ptr))?;
        self.jit_compile_call_barrier(nbarriers, total_barriers);

        self.builder.build_return(None)?;

        if function.verify(true) {
            Ok(function)
        } else {
            function.print_to_stderr();
            unsafe {
                function.delete();
            }
            Err(anyhow!("Invalid generated function."))
        }
    }

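    /// Compile the `mass` entry point, which applies the mass matrix (the
    /// left-hand side of the model) to `dudt` and writes the result into
    /// `rr`. The body is empty unless the model has both a state derivative
    /// and an explicit left-hand side.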
    pub fn compile_mass<'m>(&mut self, model: &'m DiscreteModel) -> Result<FunctionValue<'ctx>> {
        self.clear();
        let void_type = self.context.void_type();
        let fn_type = void_type.fn_type(
            &[
                self.real_type.into(),
                self.real_ptr_type.into(),
                self.real_ptr_type.into(),
                self.real_ptr_type.into(),
                self.int_type.into(),
                self.int_type.into(),
            ],
            false,
        );
        let fn_arg_names = &["t", "dudt", "data", "rr", "thread_id", "thread_dim"];
        let function = self.module.add_function("mass", fn_type, None);

        // mark the pointer arguments as noalias
        let alias_id = Attribute::get_named_enum_kind_id("noalias");
        let noalias_attr = self.context.create_enum_attribute(alias_id, 0);
        for i in &[1, 2, 3] {
            function.add_attribute(AttributeLoc::Param(*i), noalias_attr);
        }

        let basic_block = self.context.append_basic_block(function, "entry");
        self.fn_value_opt = Some(function);
        self.builder.position_at_end(basic_block);

        for (i, arg) in function.get_param_iter().enumerate() {
            let name = fn_arg_names[i];
            let alloca = self.function_arg_alloca(name, arg);
            self.insert_param(name, alloca);
        }

        // only generate a body if we have both a state_dot and an lhs
        if model.state_dot().is_some() && model.lhs().is_some() {
            let state_dot = model.state_dot().unwrap();
            let lhs = model.lhs().unwrap();

            self.insert_dot_state(state_dot);
            self.insert_data(model);
            self.insert_indices();

            // calculate time-dependent definitions
            let mut nbarriers = 0;
            let total_barriers =
                (model.time_dep_defns().len() + model.dstate_dep_defns().len() + 1) as u64;
            for tensor in model.time_dep_defns() {
                self.jit_compile_tensor(tensor, Some(*self.get_var(tensor)))?;
                self.jit_compile_call_barrier(nbarriers, total_barriers);
                nbarriers += 1;
            }

            // calculate dstate-dependent definitions
            for a in model.dstate_dep_defns() {
                self.jit_compile_tensor(a, Some(*self.get_var(a)))?;
                self.jit_compile_call_barrier(nbarriers, total_barriers);
                nbarriers += 1;
            }

            // mass
            let res_ptr = self.get_param("rr");
            self.jit_compile_tensor(lhs, Some(*res_ptr))?;
            self.jit_compile_call_barrier(nbarriers, total_barriers);
        }

        self.builder.build_return(None)?;

        if function.verify(true) {
            Ok(function)
        } else {
            function.print_to_stderr();
            unsafe {
                function.delete();
            }
            Err(anyhow!("Invalid generated function."))
        }
    }

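    /// Differentiate `original_function` with Enzyme and wrap the result in a
    /// new function named `fn_name`. Each original argument appears in the
    /// wrapper, followed by a shadow (derivative) argument when its
    /// `CompileGradientArgType` is `Dup` or `DupNoNeed`; `Const` arguments
    /// have no shadow. Forward modes call `EnzymeCreateForwardDiff`, reverse
    /// modes call `EnzymeCreatePrimalAndGradient`.
    ///
    /// A hypothetical use, differentiating `rhs` with respect to `u` (the
    /// receiver name and the chosen activities are assumptions, purely for
    /// illustration):
    ///
    /// ```ignore
    /// let rhs = codegen.compile_rhs(&model, false)?;
    /// let rhs_grad = codegen.compile_gradient(
    ///     rhs,
    ///     &[
    ///         CompileGradientArgType::Const,     // t
    ///         CompileGradientArgType::Dup,       // u, followed by its shadow du
    ///         CompileGradientArgType::Dup,       // data
    ///         CompileGradientArgType::DupNoNeed, // rr
    ///         CompileGradientArgType::Const,     // thread_id
    ///         CompileGradientArgType::Const,     // thread_dim
    ///     ],
    ///     CompileMode::Forward,
    ///     "rhs_grad",
    /// )?;
    /// ```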
    pub fn compile_gradient(
        &mut self,
        original_function: FunctionValue<'ctx>,
        args_type: &[CompileGradientArgType],
        mode: CompileMode,
        fn_name: &str,
    ) -> Result<FunctionValue<'ctx>> {
        self.clear();

        // construct the gradient function
        let mut fn_type: Vec<BasicMetadataTypeEnum> = Vec::new();

        let orig_fn_type_ptr = Self::fn_pointer_type(self.context, original_function.get_type());

        let mut enzyme_fn_type: Vec<BasicMetadataTypeEnum> = vec![orig_fn_type_ptr.into()];
        let mut start_param_index: Vec<u32> = Vec::new();
        let mut ptr_arg_indices: Vec<u32> = Vec::new();
        for (i, arg) in original_function.get_param_iter().enumerate() {
            start_param_index.push(u32::try_from(fn_type.len()).unwrap());
            let arg_type = arg.get_type();
            fn_type.push(arg_type.into());

            // constant args with type T in the original function have 2 args of type [int, T]
            enzyme_fn_type.push(self.int_type.into());
            enzyme_fn_type.push(arg.get_type().into());

            if arg_type.is_pointer_type() {
                ptr_arg_indices.push(u32::try_from(i).unwrap());
            }

            match args_type[i] {
                CompileGradientArgType::Dup | CompileGradientArgType::DupNoNeed => {
                    fn_type.push(arg.get_type().into());
                    enzyme_fn_type.push(arg.get_type().into());
                }
                CompileGradientArgType::Const => {}
            }
        }
        let void_type = self.context.void_type();
        let fn_type = void_type.fn_type(fn_type.as_slice(), false);
        let function = self.module.add_function(fn_name, fn_type, None);

        // mark the pointer arguments as noalias
        let alias_id = Attribute::get_named_enum_kind_id("noalias");
        let noalias_attr = self.context.create_enum_attribute(alias_id, 0);
        for i in ptr_arg_indices {
            function.add_attribute(AttributeLoc::Param(i), noalias_attr);
        }

        let basic_block = self.context.append_basic_block(function, "entry");
        self.fn_value_opt = Some(function);
        self.builder.position_at_end(basic_block);

        let mut enzyme_fn_args: Vec<BasicMetadataValueEnum> = Vec::new();
        let mut input_activity = Vec::new();
        let mut arg_trees = Vec::new();
        for (i, arg) in original_function.get_param_iter().enumerate() {
            let param_index = start_param_index[i];
            let fn_arg = function.get_nth_param(param_index).unwrap();

            // we'll probably only get reals or pointers to reals, so let's assume this for now
            // todo: perhaps refactor this into a recursive function, might be overkill
            let concrete_type = match arg.get_type() {
                BasicTypeEnum::PointerType(_t) => CConcreteType_DT_Pointer,
                BasicTypeEnum::FloatType(_t) => match self.diffsl_real_type {
                    RealType::F32 => CConcreteType_DT_Float,
                    RealType::F64 => CConcreteType_DT_Double,
                },
                BasicTypeEnum::IntType(_) => CConcreteType_DT_Integer,
                _ => panic!("unsupported type"),
            };
            let new_tree = unsafe {
                EnzymeNewTypeTreeCT(
                    concrete_type,
                    self.context.as_ctx_ref() as *mut LLVMOpaqueContext,
                )
            };
            unsafe { EnzymeTypeTreeOnlyEq(new_tree, -1) };

            // pointer to real type
            if concrete_type == CConcreteType_DT_Pointer {
                let inner_concrete_type = match self.diffsl_real_type {
                    RealType::F32 => CConcreteType_DT_Float,
                    RealType::F64 => CConcreteType_DT_Double,
                };
                let inner_new_tree = unsafe {
                    EnzymeNewTypeTreeCT(
                        inner_concrete_type,
                        self.context.as_ctx_ref() as *mut LLVMOpaqueContext,
                    )
                };
                unsafe { EnzymeTypeTreeOnlyEq(inner_new_tree, -1) };
                unsafe { EnzymeTypeTreeOnlyEq(inner_new_tree, -1) };
                unsafe { EnzymeMergeTypeTree(new_tree, inner_new_tree) };
            }

            arg_trees.push(new_tree);
            match args_type[i] {
                CompileGradientArgType::Dup => {
                    // pass in the arg value
                    enzyme_fn_args.push(fn_arg.into());

                    // pass in the darg value
                    let fn_darg = function.get_nth_param(param_index + 1).unwrap();
                    enzyme_fn_args.push(fn_darg.into());

                    input_activity.push(CDIFFE_TYPE_DFT_DUP_ARG);
                }
                CompileGradientArgType::DupNoNeed => {
                    // pass in the arg value
                    enzyme_fn_args.push(fn_arg.into());

                    // pass in the darg value
                    let fn_darg = function.get_nth_param(param_index + 1).unwrap();
                    enzyme_fn_args.push(fn_darg.into());

                    input_activity.push(CDIFFE_TYPE_DFT_DUP_NONEED);
                }
                CompileGradientArgType::Const => {
                    // pass in the arg value
                    enzyme_fn_args.push(fn_arg.into());

                    input_activity.push(CDIFFE_TYPE_DFT_CONSTANT);
                }
            }
        }
        // the original function returns void, so these must be false
        let ret_primary_ret = false;
        let diff_ret = false;
        let ret_activity = CDIFFE_TYPE_DFT_CONSTANT;
        let ret_tree = unsafe {
            EnzymeNewTypeTreeCT(
                CConcreteType_DT_Anything,
                self.context.as_ctx_ref() as *mut LLVMOpaqueContext,
            )
        };

        // always optimize
        let fnc_opt_base = true;
        let logic_ref: EnzymeLogicRef = unsafe { CreateEnzymeLogic(fnc_opt_base as u8) };

        let kv_tmp = IntList {
            data: std::ptr::null_mut(),
            size: 0,
        };
        let mut known_values = vec![kv_tmp; input_activity.len()];

        let fn_type_info = CFnTypeInfo {
            Arguments: arg_trees.as_mut_ptr(),
            Return: ret_tree,
            KnownValues: known_values.as_mut_ptr(),
        };

        let type_analysis: EnzymeTypeAnalysisRef =
            unsafe { CreateTypeAnalysis(logic_ref, std::ptr::null_mut(), std::ptr::null_mut(), 0) };

        let mut args_uncacheable = vec![0; arg_trees.len()];

        let enzyme_function = match mode {
            CompileMode::Forward | CompileMode::ForwardSens => unsafe {
                EnzymeCreateForwardDiff(
                    logic_ref, // logic
                    std::ptr::null_mut(),
                    std::ptr::null_mut(),
                    original_function.as_value_ref(), // LLVM function to differentiate
                    ret_activity,                     // return activity
                    input_activity.as_mut_ptr(),      // argument activities
                    input_activity.len(),             // number of arguments
                    type_analysis,                    // type analysis struct
                    ret_primary_ret as u8,            // return the primary return value
                    CDerivativeMode_DEM_ForwardMode,  // derivative mode
                    1,                                // free memory
                    0,                                // runtime activity
                    0,                                // strong zero
                    1,                                // vector mode width
                    std::ptr::null_mut(),             // additional argument
                    fn_type_info,                     // type info (return + args)
                    1,                                // subsequent calls may write
                    args_uncacheable.as_mut_ptr(),    // overwritten args
                    args_uncacheable.len(),           // overwritten args length
                    std::ptr::null_mut(),             // write augmented function to this
                )
            },
            CompileMode::Reverse | CompileMode::ReverseSens => {
                let mut call_enzyme = || unsafe {
                    EnzymeCreatePrimalAndGradient(
                        logic_ref,
                        std::ptr::null_mut(),
                        std::ptr::null_mut(),
                        original_function.as_value_ref(),
                        ret_activity,
                        input_activity.as_mut_ptr(),
                        input_activity.len(),
                        type_analysis,
                        ret_primary_ret as u8,
                        diff_ret as u8,
                        CDerivativeMode_DEM_ReverseModeCombined,
                        0,
                        0, // strong zero
                        1,
                        1,
                        std::ptr::null_mut(),
                        0, // force anonymous tape
                        fn_type_info,
                        0, // subsequent calls may write
                        args_uncacheable.as_mut_ptr(),
                        args_uncacheable.len(),
                        std::ptr::null_mut(),
                        if self.threaded { 1 } else { 0 }, // atomic add
                    )
                };
                if self.threaded {
                    // the register call handler alters a global variable, so we need to lock it
                    let _lock = my_mutex.lock().unwrap();
                    let barrier_string = CString::new("barrier").unwrap();
                    unsafe {
                        EnzymeRegisterCallHandler(
                            barrier_string.as_ptr(),
                            Some(fwd_handler),
                            Some(rev_handler),
                        )
                    };
                    let ret = call_enzyme();
                    // unregister it so another thread doesn't pick it up
                    unsafe { EnzymeRegisterCallHandler(barrier_string.as_ptr(), None, None) };
                    ret
                } else {
                    call_enzyme()
                }
            }
        };

        // free everything
        unsafe { FreeEnzymeLogic(logic_ref) };
        unsafe { FreeTypeAnalysis(type_analysis) };
        unsafe { EnzymeFreeTypeTree(ret_tree) };
        for tree in arg_trees {
            unsafe { EnzymeFreeTypeTree(tree) };
        }

        // call the function generated by enzyme
        let enzyme_function =
            unsafe { FunctionValue::new(enzyme_function as LLVMValueRef) }.unwrap();
        self.builder
            .build_call(enzyme_function, enzyme_fn_args.as_slice(), "enzyme_call")?;

        // return
        self.builder.build_return(None)?;

        if function.verify(true) {
            Ok(function)
        } else {
            function.print_to_stderr();
            enzyme_function.print_to_stderr();
            unsafe {
                function.delete();
            }
            Err(anyhow!("Invalid generated function."))
        }
    }

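    /// Compile the `get_dims` entry point, which writes the model's
    /// dimensions (number of states, inputs, outputs, data entries and stop
    /// conditions, plus a has_mass flag) through the supplied pointers.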
    pub fn compile_get_dims(&mut self, model: &DiscreteModel) -> Result<FunctionValue<'ctx>> {
        self.clear();
        let fn_type = self.context.void_type().fn_type(
            &[
                self.int_ptr_type.into(),
                self.int_ptr_type.into(),
                self.int_ptr_type.into(),
                self.int_ptr_type.into(),
                self.int_ptr_type.into(),
                self.int_ptr_type.into(),
            ],
            false,
        );

        let function = self.module.add_function("get_dims", fn_type, None);
        let block = self.context.append_basic_block(function, "entry");
        let fn_arg_names = &["states", "inputs", "outputs", "data", "stop", "has_mass"];
        self.builder.position_at_end(block);

        for (i, arg) in function.get_param_iter().enumerate() {
            let name = fn_arg_names[i];
            let alloca = self.function_arg_alloca(name, arg);
            self.insert_param(name, alloca);
        }

        self.insert_indices();

        let number_of_states = model.state().nnz() as u64;
        let number_of_inputs = model.input().map(|inp| inp.nnz()).unwrap_or(0) as u64;
        let number_of_outputs = model.out().map(|out| out.nnz()).unwrap_or(0) as u64;
        let number_of_stop = model.stop().map(|stop| stop.nnz()).unwrap_or(0) as u64;
        let has_mass = if model.lhs().is_some() { 1u64 } else { 0u64 };
        let data_len = self.layout.data().len() as u64;
        self.builder.build_store(
            *self.get_param("states"),
            self.int_type.const_int(number_of_states, false),
        )?;
        self.builder.build_store(
            *self.get_param("inputs"),
            self.int_type.const_int(number_of_inputs, false),
        )?;
        self.builder.build_store(
            *self.get_param("outputs"),
            self.int_type.const_int(number_of_outputs, false),
        )?;
        self.builder.build_store(
            *self.get_param("data"),
            self.int_type.const_int(data_len, false),
        )?;
        self.builder.build_store(
            *self.get_param("stop"),
            self.int_type.const_int(number_of_stop, false),
        )?;
        self.builder.build_store(
            *self.get_param("has_mass"),
            self.int_type.const_int(has_mass, false),
        )?;
        self.builder.build_return(None)?;

        if function.verify(true) {
            Ok(function)
        } else {
            function.print_to_stderr();
            unsafe {
                function.delete();
            }
            Err(anyhow!("Invalid generated function."))
        }
    }

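    /// Compile a `get_tensor_<name>` entry point, which returns a pointer to
    /// the named tensor inside `data` along with its number of stored
    /// (non-zero) entries.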
    pub fn compile_get_tensor(
        &mut self,
        model: &DiscreteModel,
        name: &str,
    ) -> Result<FunctionValue<'ctx>> {
        self.clear();
        let real_ptr_ptr_type = Self::pointer_type(self.context, self.real_ptr_type.into());
        let fn_type = self.context.void_type().fn_type(
            &[
                self.real_ptr_type.into(),
                real_ptr_ptr_type.into(),
                self.int_ptr_type.into(),
            ],
            false,
        );
        let function_name = format!("get_tensor_{name}");
        let function = self
            .module
            .add_function(function_name.as_str(), fn_type, None);
        let basic_block = self.context.append_basic_block(function, "entry");
        self.fn_value_opt = Some(function);

        let fn_arg_names = &["data", "tensor_data", "tensor_size"];
        self.builder.position_at_end(basic_block);

        for (i, arg) in function.get_param_iter().enumerate() {
            let name = fn_arg_names[i];
            let alloca = self.function_arg_alloca(name, arg);
            self.insert_param(name, alloca);
        }

        self.insert_data(model);
        let ptr = self.get_param(name);
        let tensor_size = self.layout.get_layout(name).unwrap().nnz() as u64;
        let tensor_size_value = self.int_type.const_int(tensor_size, false);
        self.builder
            .build_store(*self.get_param("tensor_data"), ptr.as_basic_value_enum())?;
        self.builder
            .build_store(*self.get_param("tensor_size"), tensor_size_value)?;
        self.builder.build_return(None)?;

        if function.verify(true) {
            Ok(function)
        } else {
            function.print_to_stderr();
            unsafe {
                function.delete();
            }
            Err(anyhow!("Invalid generated function."))
        }
    }

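    /// Compile a `get_constant_<name>` entry point; like `get_tensor_<name>`
    /// except that the tensor lives in constant storage rather than in
    /// `data`, so no data pointer argument is needed.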
    pub fn compile_get_constant(
        &mut self,
        model: &DiscreteModel,
        name: &str,
    ) -> Result<FunctionValue<'ctx>> {
        self.clear();
        let real_ptr_ptr_type = Self::pointer_type(self.context, self.real_ptr_type.into());
        let fn_type = self
            .context
            .void_type()
            .fn_type(&[real_ptr_ptr_type.into(), self.int_ptr_type.into()], false);
        let function_name = format!("get_constant_{name}");
        let function = self
            .module
            .add_function(function_name.as_str(), fn_type, None);
        let basic_block = self.context.append_basic_block(function, "entry");
        self.fn_value_opt = Some(function);

        let fn_arg_names = &["tensor_data", "tensor_size"];
        self.builder.position_at_end(basic_block);

        for (i, arg) in function.get_param_iter().enumerate() {
            let name = fn_arg_names[i];
            let alloca = self.function_arg_alloca(name, arg);
            self.insert_param(name, alloca);
        }

        self.insert_constants(model);
        let ptr = self.get_param(name);
        let tensor_size = self.layout.get_layout(name).unwrap().nnz() as u64;
        let tensor_size_value = self.int_type.const_int(tensor_size, false);
        self.builder
            .build_store(*self.get_param("tensor_data"), ptr.as_basic_value_enum())?;
        self.builder
            .build_store(*self.get_param("tensor_size"), tensor_size_value)?;
        self.builder.build_return(None)?;

        if function.verify(true) {
            Ok(function)
        } else {
            function.print_to_stderr();
            unsafe {
                function.delete();
            }
            Err(anyhow!("Invalid generated function."))
        }
    }

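    /// Compile the `set_inputs` or `get_inputs` entry point, which copies the
    /// model's input tensor between the caller's `inputs` buffer and `data`.
    /// `is_get` selects the copy direction (data to inputs when true). The
    /// generated body is a single counted loop over the input's stored
    /// entries.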
    pub fn compile_inputs(
        &mut self,
        model: &DiscreteModel,
        is_get: bool,
    ) -> Result<FunctionValue<'ctx>> {
        self.clear();
        let void_type = self.context.void_type();
        let fn_type = void_type.fn_type(
            &[self.real_ptr_type.into(), self.real_ptr_type.into()],
            false,
        );
        let function_name = if is_get { "get_inputs" } else { "set_inputs" };
        let function = self.module.add_function(function_name, fn_type, None);
        let block = self.context.append_basic_block(function, "entry");
        self.fn_value_opt = Some(function);

        let fn_arg_names = &["inputs", "data"];
        self.builder.position_at_end(block);

        for (i, arg) in function.get_param_iter().enumerate() {
            let name = fn_arg_names[i];
            let alloca = self.function_arg_alloca(name, arg);
            self.insert_param(name, alloca);
        }

        if let Some(input) = model.input() {
            let name = input.name();
            self.insert_tensor(input, false);
            let ptr = self.get_var(input);
            // loop through the elements of this input and set/get them using the inputs ptr
            let inputs_start_index = self.int_type.const_int(0, false);
            let start_index = self.int_type.const_int(0, false);
            let end_index = self
                .int_type
                .const_int(input.nnz().try_into().unwrap(), false);

            let input_block = self.context.append_basic_block(function, name);
            self.builder.build_unconditional_branch(input_block)?;
            self.builder.position_at_end(input_block);
            let index = self.builder.build_phi(self.int_type, "i")?;
            index.add_incoming(&[(&start_index, block)]);

            // loop body - copy the value between inputs and data
            let curr_input_index = index.as_basic_value().into_int_value();
            let input_ptr =
                Self::get_ptr_to_index(&self.builder, self.real_type, ptr, curr_input_index, name);
            let curr_inputs_index =
                self.builder
                    .build_int_add(inputs_start_index, curr_input_index, name)?;
            let inputs_ptr = Self::get_ptr_to_index(
                &self.builder,
                self.real_type,
                self.get_param("inputs"),
                curr_inputs_index,
                name,
            );
            if is_get {
                let input_value = self
                    .build_load(self.real_type, input_ptr, name)?
                    .into_float_value();
                self.builder.build_store(inputs_ptr, input_value)?;
            } else {
                let input_value = self
                    .build_load(self.real_type, inputs_ptr, name)?
                    .into_float_value();
                self.builder.build_store(input_ptr, input_value)?;
            }

            // increment loop index
            let one = self.int_type.const_int(1, false);
            let next_index = self.builder.build_int_add(curr_input_index, one, name)?;
            index.add_incoming(&[(&next_index, input_block)]);

            // loop condition
            let loop_while =
                self.builder
                    .build_int_compare(IntPredicate::ULT, next_index, end_index, name)?;
            let post_block = self.context.append_basic_block(function, name);
            self.builder
                .build_conditional_branch(loop_while, input_block, post_block)?;
            self.builder.position_at_end(post_block);
        }
        self.builder.build_return(None)?;

        if function.verify(true) {
            Ok(function)
        } else {
            function.print_to_stderr();
            unsafe {
                function.delete();
            }
            Err(anyhow!("Invalid generated function."))
        }
    }

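    /// Compile the `set_id` entry point, which marks each state entry in the
    /// `id` buffer as differential (1.0) or algebraic (0.0), using one
    /// counted loop per state block.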
    pub fn compile_set_id(&mut self, model: &DiscreteModel) -> Result<FunctionValue<'ctx>> {
        self.clear();
        let void_type = self.context.void_type();
        let fn_type = void_type.fn_type(&[self.real_ptr_type.into()], false);
        let function = self.module.add_function("set_id", fn_type, None);
        let mut block = self.context.append_basic_block(function, "entry");

        let fn_arg_names = &["id"];
        self.builder.position_at_end(block);

        for (i, arg) in function.get_param_iter().enumerate() {
            let name = fn_arg_names[i];
            let alloca = self.function_arg_alloca(name, arg);
            self.insert_param(name, alloca);
        }

        let mut id_index = 0usize;
        for (blk, is_algebraic) in zip(model.state().elmts(), model.is_algebraic()) {
            let name = blk.name().unwrap_or("unknown");
            // loop through the elements of this state blk and set the corresponding elements of id
            let id_start_index = self.int_type.const_int(id_index as u64, false);
            let blk_start_index = self.int_type.const_int(0, false);
            let blk_end_index = self
                .int_type
                .const_int(blk.nnz().try_into().unwrap(), false);

            let blk_block = self.context.append_basic_block(function, name);
            self.builder.build_unconditional_branch(blk_block)?;
            self.builder.position_at_end(blk_block);
            let index = self.builder.build_phi(self.int_type, "i")?;
            index.add_incoming(&[(&blk_start_index, block)]);

            // loop body - store 0.0 (algebraic) or 1.0 (differential) into the corresponding element of id
            let curr_blk_index = index.as_basic_value().into_int_value();
            let curr_id_index = self
                .builder
                .build_int_add(id_start_index, curr_blk_index, name)?;
            let id_ptr = Self::get_ptr_to_index(
                &self.builder,
                self.real_type,
                self.get_param("id"),
                curr_id_index,
                name,
            );
            let is_algebraic_float = if *is_algebraic { 0.0 } else { 1.0 };
            let is_algebraic_value = self.real_type.const_float(is_algebraic_float);
            self.builder.build_store(id_ptr, is_algebraic_value)?;

            // increment loop index
            let one = self.int_type.const_int(1, false);
            let next_index = self.builder.build_int_add(curr_blk_index, one, name)?;
            index.add_incoming(&[(&next_index, blk_block)]);

            // loop condition
            let loop_while = self.builder.build_int_compare(
                IntPredicate::ULT,
                next_index,
                blk_end_index,
                name,
            )?;
            let post_block = self.context.append_basic_block(function, name);
            self.builder
                .build_conditional_branch(loop_while, blk_block, post_block)?;
            self.builder.position_at_end(post_block);

            // get ready for the next blk
            block = post_block;
            id_index += blk.nnz();
        }
        self.builder.build_return(None)?;

        if function.verify(true) {
            Ok(function)
        } else {
            function.print_to_stderr();
            unsafe {
                function.delete();
            }
            Err(anyhow!("Invalid generated function."))
        }
    }

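    /// Return a reference to the underlying LLVM module.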
    pub fn module(&self) -> &Module<'ctx> {
        &self.module
    }
}