diffsl/execution/llvm/codegen.rs

use aliasable::boxed::AliasableBox;
use anyhow::{anyhow, Result};
use core::panic;
use inkwell::attributes::{Attribute, AttributeLoc};
use inkwell::basic_block::BasicBlock;
use inkwell::builder::Builder;
use inkwell::context::{AsContextRef, Context};
use inkwell::debug_info::AsDIScope;
use inkwell::debug_info::{DICompileUnit, DIFlags, DIFlagsConstants, DebugInfoBuilder};
use inkwell::execution_engine::ExecutionEngine;
use inkwell::intrinsics::Intrinsic;
use inkwell::module::{Linkage, Module};
use inkwell::passes::PassBuilderOptions;
use inkwell::targets::{FileType, InitializationConfig, Target, TargetMachine, TargetTriple};
use inkwell::types::{
    BasicMetadataTypeEnum, BasicType, BasicTypeEnum, FloatType, FunctionType, IntType, PointerType,
};
use inkwell::values::{
    AsValueRef, BasicMetadataValueEnum, BasicValue, BasicValueEnum, CallSiteValue, FloatValue,
    FunctionValue, GlobalValue, IntValue, PointerValue,
};
use inkwell::{
    AddressSpace, AtomicOrdering, AtomicRMWBinOp, FloatPredicate, GlobalVisibility, IntPredicate,
    OptimizationLevel,
};
use llvm_sys::core::{
    LLVMBuildCall2, LLVMGetArgOperand, LLVMGetBasicBlockParent, LLVMGetGlobalParent,
    LLVMGetInstructionParent, LLVMGetNamedFunction, LLVMGlobalGetValueType, LLVMIsMultithreaded,
};
use llvm_sys::prelude::{LLVMBuilderRef, LLVMValueRef};
use pest::Span;
use std::collections::HashMap;
use std::ffi::CString;
use std::iter::zip;
use std::pin::Pin;
use target_lexicon::Triple;

use crate::ast::{Ast, AstKind};
use crate::discretise::{DiscreteModel, Tensor, TensorBlock};
use crate::enzyme::{
    CConcreteType_DT_Anything, CConcreteType_DT_Double, CConcreteType_DT_Float,
    CConcreteType_DT_Integer, CConcreteType_DT_Pointer, CDerivativeMode_DEM_ForwardMode,
    CDerivativeMode_DEM_ReverseModeCombined, CFnTypeInfo, CreateEnzymeLogic, CreateTypeAnalysis,
    DiffeGradientUtils, EnzymeCreateForwardDiff, EnzymeCreatePrimalAndGradient, EnzymeFreeTypeTree,
    EnzymeGradientUtilsNewFromOriginal, EnzymeLogicRef, EnzymeMergeTypeTree, EnzymeNewTypeTreeCT,
    EnzymeRegisterCallHandler, EnzymeTypeAnalysisRef, EnzymeTypeTreeOnlyEq, FreeEnzymeLogic,
    FreeTypeAnalysis, GradientUtils, IntList, LLVMOpaqueContext, CDIFFE_TYPE_DFT_CONSTANT,
    CDIFFE_TYPE_DFT_DUP_ARG, CDIFFE_TYPE_DFT_DUP_NONEED,
};
use crate::execution::compiler::CompilerOptions;
use crate::execution::module::{
    CodegenModule, CodegenModuleCompile, CodegenModuleEmit, CodegenModuleJit,
};
use crate::execution::scalar::RealType;
use crate::execution::{DataLayout, Translation, TranslationFrom, TranslationTo};
use lazy_static::lazy_static;
use std::sync::Mutex;

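// Presumably serialises calls into Enzyme, whose global registration and
// logic state are not thread-safe.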
lazy_static! {
    static ref my_mutex: Mutex<i32> = Mutex::new(0i32);
}

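/// Self-referential pair: `codegen` borrows from `context`, so the whole
/// struct is pinned and `context` is held in an `AliasableBox` that is never
/// moved out of while `codegen` is alive.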
struct ImmovableLlvmModule {
    // actually has lifetime of `context`
    // declared first so it's dropped before `context`
    codegen: Option<CodeGen<'static>>,
    // safety: we must never move out of this box as long as codegen is alive
    context: AliasableBox<Context>,
    _pin: std::marker::PhantomPinned,
}

pub struct LlvmModule {
    inner: Pin<Box<ImmovableLlvmModule>>,
    machine: TargetMachine,
}

unsafe impl Send for LlvmModule {}
unsafe impl Sync for LlvmModule {}

impl LlvmModule {
    fn new(
        triple: Option<Triple>,
        model: &DiscreteModel,
        threaded: bool,
        real_type: RealType,
        debug: bool,
    ) -> Result<Self> {
        let initialization_config = &InitializationConfig::default();
        Target::initialize_all(initialization_config);
        let host_triple = Triple::host();
        let (triple_str, native) = match triple {
            Some(ref triple) => (triple.to_string(), false),
            None => (host_triple.to_string(), true),
        };
        let triple = TargetTriple::create(triple_str.as_str());
        let target = Target::from_triple(&triple).unwrap();
        let cpu = if native {
            TargetMachine::get_host_cpu_name().to_string()
        } else {
            "generic".to_string()
        };
        let features = if native {
            TargetMachine::get_host_cpu_features().to_string()
        } else {
            "".to_string()
        };
        let machine = target
            .create_target_machine(
                &triple,
                cpu.as_str(),
                features.as_str(),
                inkwell::OptimizationLevel::Aggressive,
                inkwell::targets::RelocMode::Default,
                inkwell::targets::CodeModel::Default,
            )
            .unwrap();

        let context = AliasableBox::from_unique(Box::new(Context::create()));
        let mut pinned = Self {
            inner: Box::pin(ImmovableLlvmModule {
                codegen: None,
                context,
                _pin: std::marker::PhantomPinned,
            }),
            machine,
        };

        let context_ref = pinned.inner.context.as_ref();
        let real_type_llvm = match real_type {
            RealType::F32 => context_ref.f32_type(),
            RealType::F64 => context_ref.f64_type(),
        };
        let int_type_llvm = context_ref.i32_type();
        let ptr_size_bits = pinned
            .machine
            .get_target_data()
            .get_bit_size(&context_ref.ptr_type(AddressSpace::default()));
        let real_size_bits = pinned
            .machine
            .get_target_data()
            .get_bit_size(&real_type_llvm);
        let int_size_bits = pinned
            .machine
            .get_target_data()
            .get_bit_size(&int_type_llvm);
        let codegen = CodeGen::new(
            model,
            context_ref,
            real_type,
            real_type_llvm,
            int_type_llvm,
            threaded,
            ptr_size_bits,
            real_size_bits,
            int_size_bits,
            debug,
        )?;
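        // Safety: erase the borrow of `context` to 'static. The context is
        // pinned and is dropped after `codegen` (see the field order on
        // `ImmovableLlvmModule`), so the reference never dangles.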
        let codegen = unsafe { std::mem::transmute::<CodeGen<'_>, CodeGen<'static>>(codegen) };
        unsafe { pinned.inner.as_mut().get_unchecked_mut().codegen = Some(codegen) };
        Ok(pinned)
    }

    fn pre_autodiff_optimisation(&mut self) -> Result<()> {
        //let pass_manager_builder = PassManagerBuilder::create();
        //pass_manager_builder.set_optimization_level(inkwell::OptimizationLevel::Default);
        //let pass_manager = PassManager::create(());
        //pass_manager_builder.populate_module_pass_manager(&pass_manager);
        //pass_manager.run_on(self.codegen().module());

        //self.codegen().module().print_to_stderr();
        // optimise at -O2 with no unrolling before giving to enzyme
        let pass_options = PassBuilderOptions::create();
        //pass_options.set_verify_each(true);
        //pass_options.set_debug_logging(true);
        //pass_options.set_loop_interleaving(true);
        pass_options.set_loop_vectorization(false);
        pass_options.set_loop_slp_vectorization(false);
        pass_options.set_loop_unrolling(false);
        //pass_options.set_forget_all_scev_in_loop_unroll(true);
        //pass_options.set_licm_mssa_opt_cap(1);
        //pass_options.set_licm_mssa_no_acc_for_promotion_cap(10);
        //pass_options.set_call_graph_profile(true);
        //pass_options.set_merge_functions(true);
        //let path = "jit_module_before_pre_autodiff_opt.ll";
        //self.codegen()
        //    .module()
        //    .print_to_file(path)
        //    .map_err(|e| anyhow!("Failed to print module to file: {:?}", e))?;

        //let passes = "default<O2>";
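        // The pipeline below appears to be the `default<O2>` pipeline written
        // out pass-by-pass, with the loop vectorisation/unrolling passes
        // removed so that Enzyme sees straightforward scalar loops.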
        let passes = "annotation2metadata,forceattrs,inferattrs,coro-early,function<eager-inv>(lower-expect,simplifycfg<bonus-inst-threshold=1;no-forward-switch-cond;no-switch-range-to-icmp;no-switch-to-lookup;keep-loops;no-hoist-common-insts;no-sink-common-insts>,early-cse<>),openmp-opt,ipsccp,called-value-propagation,globalopt,function(mem2reg),function<eager-inv>(instcombine,simplifycfg<bonus-inst-threshold=1;no-forward-switch-cond;switch-range-to-icmp;no-switch-to-lookup;keep-loops;no-hoist-common-insts;no-sink-common-insts>),require<globals-aa>,function(invalidate<aa>),require<profile-summary>,cgscc(devirt<4>(inline<only-mandatory>,inline,function-attrs,openmp-opt-cgscc,function<eager-inv>(early-cse<memssa>,speculative-execution,jump-threading,correlated-propagation,simplifycfg<bonus-inst-threshold=1;no-forward-switch-cond;switch-range-to-icmp;no-switch-to-lookup;keep-loops;no-hoist-common-insts;no-sink-common-insts>,instcombine,libcalls-shrinkwrap,tailcallelim,simplifycfg<bonus-inst-threshold=1;no-forward-switch-cond;switch-range-to-icmp;no-switch-to-lookup;keep-loops;no-hoist-common-insts;no-sink-common-insts>,reassociate,require<opt-remark-emit>,loop-mssa(loop-instsimplify,loop-simplifycfg,licm<no-allowspeculation>,loop-rotate,licm<allowspeculation>,simple-loop-unswitch<no-nontrivial;trivial>),simplifycfg<bonus-inst-threshold=1;no-forward-switch-cond;switch-range-to-icmp;no-switch-to-lookup;keep-loops;no-hoist-common-insts;no-sink-common-insts>,instcombine,loop(loop-idiom,indvars,loop-deletion),vector-combine,mldst-motion<no-split-footer-bb>,gvn<>,sccp,bdce,instcombine,jump-threading,correlated-propagation,adce,memcpyopt,dse,loop-mssa(licm<allowspeculation>),coro-elide,simplifycfg<bonus-inst-threshold=1;no-forward-switch-cond;switch-range-to-icmp;no-switch-to-lookup;keep-loops;hoist-common-insts;sink-common-insts>,instcombine),coro-split)),deadargelim,coro-cleanup,globalopt,globaldce,elim-avail-extern,rpo-function-attrs,recompute-globalsaa,function<eager-inv>(float2int,lower-constant-intrinsics,loop(loop-rotate,loop-deletion),loop-distribute,inject-tli-mappings,loop-load-elim,instcombine,simplifycfg<bonus-inst-threshold=1;forward-switch-cond;switch-range-to-icmp;switch-to-lookup;no-keep-loops;hoist-common-insts;sink-common-insts>,vector-combine,instcombine,transform-warning,instcombine,require<opt-remark-emit>,loop-mssa(licm<allowspeculation>),alignment-from-assumptions,loop-sink,instsimplify,div-rem-pairs,tailcallelim,simplifycfg<bonus-inst-threshold=1;no-forward-switch-cond;switch-range-to-icmp;no-switch-to-lookup;keep-loops;no-hoist-common-insts;no-sink-common-insts>),globaldce,constmerge,cg-profile,rel-lookup-table-converter,function(annotation-remarks),verify";
        let (codegen, machine) = self.codegen_and_machine_mut();
        codegen
            .module()
            .run_passes(passes, machine, pass_options)
            .map_err(|e| anyhow!("Failed to run passes: {:?}", e))

        //let path = "jit_module_after_pre_autodiff_opt.ll";
        //self.codegen()
        //    .module()
        //    .print_to_file(path)
        //    .map_err(|e| anyhow!("Failed to print module to file: {:?}", e))
    }

    fn post_autodiff_optimisation(&mut self) -> Result<()> {
        // remove the noinline attribute from the barrier function, as it is only needed for enzyme
        if let Some(barrier_func) = self.codegen_mut().module().get_function("barrier") {
            let noinline_kind_id = Attribute::get_named_enum_kind_id("noinline");
            barrier_func.remove_enum_attribute(AttributeLoc::Function, noinline_kind_id);
        }

        // remove all preprocess_* functions
        for f in self.codegen_mut().module.get_functions() {
            if f.get_name().to_str().unwrap().starts_with("preprocess_") {
                unsafe { f.delete() };
            }
        }

        //self.codegen()
        //    .module()
        //    .print_to_file("jit_module_before_post_autodiff_opt.ll")
        //    .unwrap();

        let passes = "default<O3>";
        let (codegen, machine) = self.codegen_and_machine_mut();
        codegen
            .module()
            .run_passes(passes, machine, PassBuilderOptions::create())
            .map_err(|e| anyhow!("Failed to run passes: {:?}", e))?;

        //self.codegen()
        //    .module()
        //    .print_to_file("jit_module_after_post_autodiff_opt.ll")
        //    .unwrap();

        Ok(())
    }

    pub fn print(&self) {
        self.codegen().module().print_to_stderr();
    }

    fn codegen_mut(&mut self) -> &mut CodeGen<'static> {
        unsafe {
            self.inner
                .as_mut()
                .get_unchecked_mut()
                .codegen
                .as_mut()
                .unwrap()
        }
    }

    fn codegen_and_machine_mut(&mut self) -> (&mut CodeGen<'static>, &TargetMachine) {
        (
            unsafe {
                self.inner
                    .as_mut()
                    .get_unchecked_mut()
                    .codegen
                    .as_mut()
                    .unwrap()
            },
            &self.machine,
        )
    }

    fn codegen(&self) -> &CodeGen<'static> {
        self.inner.as_ref().get_ref().codegen.as_ref().unwrap()
    }
}

impl CodegenModule for LlvmModule {}

impl CodegenModuleCompile for LlvmModule {
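    /// Compiles a `DiscreteModel` end-to-end: the core entry points are
    /// generated first, the module is optimised (`pre_autodiff_optimisation`),
    /// Enzyme-derived gradients are compiled in forward, reverse, and
    /// sensitivity modes, and the module is then optimised again
    /// (`post_autodiff_optimisation`).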
    fn from_discrete_model(
        model: &DiscreteModel,
        options: CompilerOptions,
        triple: Option<Triple>,
        real_type: RealType,
        code: Option<&str>,
    ) -> Result<Self> {
        let thread_dim = options.mode.thread_dim(model.state().nnz());
        let threaded = thread_dim > 1;
        if unsafe { LLVMIsMultithreaded() } <= 0 {
            return Err(anyhow!(
                "LLVM is not compiled with multithreading support, but this codegen module requires it."
            ));
        }

        let mut module = Self::new(triple, model, threaded, real_type, options.debug)?;

        let set_u0 = module.codegen_mut().compile_set_u0(model, code)?;
        let _calc_stop = module.codegen_mut().compile_calc_stop(model, code)?;
        let rhs = module.codegen_mut().compile_rhs(model, false, code)?;
        let rhs_full = module.codegen_mut().compile_rhs(model, true, code)?;
        let mass = module.codegen_mut().compile_mass(model, code)?;
        let calc_out = module.codegen_mut().compile_calc_out(model, false, code)?;
        let calc_out_full = module.codegen_mut().compile_calc_out(model, true, code)?;
        let _set_id = module.codegen_mut().compile_set_id(model)?;
        let _get_dims = module.codegen_mut().compile_get_dims(model)?;
        let set_inputs = module.codegen_mut().compile_inputs(model, false)?;
        let _get_inputs = module.codegen_mut().compile_inputs(model, true)?;
        let _set_constants = module.codegen_mut().compile_set_constants(model, code)?;
        let tensor_info = module
            .codegen()
            .layout
            .tensors()
            .map(|(name, is_constant)| (name.to_string(), is_constant))
            .collect::<Vec<_>>();
        for (tensor, is_constant) in tensor_info {
            if is_constant {
                module
                    .codegen_mut()
                    .compile_get_constant(model, tensor.as_str())?;
            } else {
                module
                    .codegen_mut()
                    .compile_get_tensor(model, tensor.as_str())?;
            }
        }

        module.pre_autodiff_optimisation()?;

        module.codegen_mut().compile_gradient(
            set_u0,
            &[
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::Const,
                CompileGradientArgType::Const,
            ],
            CompileMode::Forward,
            "set_u0_grad",
        )?;

        module.codegen_mut().compile_gradient(
            rhs,
            &[
                CompileGradientArgType::Const,
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::Const,
                CompileGradientArgType::Const,
            ],
            CompileMode::Forward,
            "rhs_grad",
        )?;

        module.codegen_mut().compile_gradient(
            calc_out,
            &[
                CompileGradientArgType::Const,
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::Const,
                CompileGradientArgType::Const,
            ],
            CompileMode::Forward,
            "calc_out_grad",
        )?;
        module.codegen_mut().compile_gradient(
            set_inputs,
            &[
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::DupNoNeed,
            ],
            CompileMode::Forward,
            "set_inputs_grad",
        )?;

        module.codegen_mut().compile_gradient(
            set_u0,
            &[
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::Const,
                CompileGradientArgType::Const,
            ],
            CompileMode::Reverse,
            "set_u0_rgrad",
        )?;

        module.codegen_mut().compile_gradient(
            mass,
            &[
                CompileGradientArgType::Const,
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::Const,
                CompileGradientArgType::Const,
            ],
            CompileMode::Reverse,
            "mass_rgrad",
        )?;

        module.codegen_mut().compile_gradient(
            rhs,
            &[
                CompileGradientArgType::Const,
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::Const,
                CompileGradientArgType::Const,
            ],
            CompileMode::Reverse,
            "rhs_rgrad",
        )?;
        module.codegen_mut().compile_gradient(
            calc_out,
            &[
                CompileGradientArgType::Const,
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::Const,
                CompileGradientArgType::Const,
            ],
            CompileMode::Reverse,
            "calc_out_rgrad",
        )?;

        module.codegen_mut().compile_gradient(
            set_inputs,
            &[
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::DupNoNeed,
            ],
            CompileMode::Reverse,
            "set_inputs_rgrad",
        )?;

        module.codegen_mut().compile_gradient(
            rhs_full,
            &[
                CompileGradientArgType::Const,
                CompileGradientArgType::Const,
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::Const,
                CompileGradientArgType::Const,
            ],
            CompileMode::ForwardSens,
            "rhs_sgrad",
        )?;

        module.codegen_mut().compile_gradient(
            set_u0,
            &[
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::Const,
                CompileGradientArgType::Const,
            ],
            CompileMode::ForwardSens,
            "set_u0_sgrad",
        )?;

        module.codegen_mut().compile_gradient(
            calc_out_full,
            &[
                CompileGradientArgType::Const,
                CompileGradientArgType::Const,
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::Const,
                CompileGradientArgType::Const,
            ],
            CompileMode::ForwardSens,
            "calc_out_sgrad",
        )?;
        module.codegen_mut().compile_gradient(
            calc_out_full,
            &[
                CompileGradientArgType::Const,
                CompileGradientArgType::Const,
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::Const,
                CompileGradientArgType::Const,
            ],
            CompileMode::ReverseSens,
            "calc_out_srgrad",
        )?;

        module.codegen_mut().compile_gradient(
            rhs_full,
            &[
                CompileGradientArgType::Const,
                CompileGradientArgType::Const,
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::Const,
                CompileGradientArgType::Const,
            ],
            CompileMode::ReverseSens,
            "rhs_srgrad",
        )?;

        module.post_autodiff_optimisation()?;
        Ok(module)
    }
}

impl CodegenModuleEmit for LlvmModule {
    fn to_object(self) -> Result<Vec<u8>> {
        let module = self.codegen().module();
        //module.print_to_stderr();
        let buffer = self
            .machine
            .write_to_memory_buffer(module, FileType::Object)
            .unwrap()
            .as_slice()
            .to_vec();
        Ok(buffer)
    }
}

impl CodegenModuleJit for LlvmModule {
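    /// JIT-compiles the module and returns a map from symbol name to raw
    /// function pointer, skipping any function the execution engine cannot
    /// resolve. A minimal usage sketch (assuming a compiled `LlvmModule`
    /// named `module`; the lookup key and cast are illustrative only):
    ///
    /// ```ignore
    /// let symbols = module.jit()?;
    /// let set_u0 = symbols["set_u0"]; // *const u8, cast to the real fn type before calling
    /// ```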
    fn jit(&mut self) -> Result<HashMap<String, *const u8>> {
        let ee = self
            .codegen()
            .module()
            .create_jit_execution_engine(OptimizationLevel::Default)
            .map_err(|e| anyhow!("Failed to create JIT execution engine: {:?}", e))?;

        //let path = "jit_module.ll";
        //self.codegen()
        //    .module()
        //    .print_to_file(path)
        //    .map_err(|e| anyhow!("Failed to print module to file: {:?}", e))?;
        let module = self.codegen().module();
        let mut symbols = HashMap::new();
        for function in module.get_functions() {
            let name = function.get_name().to_str().unwrap();
            let address = ee.get_function_address(name);
            if let Ok(address) = address {
                symbols.insert(name.to_string(), address as *const u8);
            }
        }
        Ok(symbols)
    }
}

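/// Module-level globals shared by the generated functions. The
/// `enzyme_const_` name prefix marks them as inactive for Enzyme (see the
/// todo comment in `new` below: the metadata approach did not achieve this).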
struct Globals<'ctx> {
    indices: Option<GlobalValue<'ctx>>,
    constants: Option<GlobalValue<'ctx>>,
    thread_counter: Option<GlobalValue<'ctx>>,
}

impl<'ctx> Globals<'ctx> {
    fn new(
        layout: &DataLayout,
        module: &Module<'ctx>,
        int_type: IntType<'ctx>,
        real_type: FloatType<'ctx>,
        threaded: bool,
    ) -> Self {
        let thread_counter = if threaded {
            let tc = module.add_global(
                int_type,
                Some(AddressSpace::default()),
                "enzyme_const_thread_counter",
            );
            // todo: for some reason this metadata doesn't make enzyme think it's inactive,
            // but using enzyme_const in the name does
            // todo: also, adding this metadata causes module printing to segfault,
            // so maybe a bug in inkwell
            //let md_string = context.metadata_string("enzyme_inactive");
            //tc.set_metadata(md_string, 0);
            let tc_value = int_type.const_zero();
            tc.set_visibility(GlobalVisibility::Hidden);
            tc.set_initializer(&tc_value.as_basic_value_enum());
            Some(tc)
        } else {
            None
        };
        let constants = if layout.constants().is_empty() {
            None
        } else {
            let constants_array_type =
                real_type.array_type(u32::try_from(layout.constants().len()).unwrap());
            let constants = module.add_global(
                constants_array_type,
                Some(AddressSpace::default()),
                "enzyme_const_constants",
            );
            constants.set_visibility(GlobalVisibility::Hidden);
            constants.set_constant(false);
            constants.set_initializer(&constants_array_type.const_zero());
            Some(constants)
        };
        let indices = if layout.indices().is_empty() {
            None
        } else {
            let indices_array_type =
                int_type.array_type(u32::try_from(layout.indices().len()).unwrap());
            let indices_array_values = layout
                .indices()
                .iter()
                .map(|&i| int_type.const_int(i as u64, true))
                .collect::<Vec<IntValue>>();
            let indices_value = int_type.const_array(indices_array_values.as_slice());
            let indices = module.add_global(
                indices_array_type,
                Some(AddressSpace::default()),
                "enzyme_const_indices",
            );
            indices.set_constant(true);
            indices.set_visibility(GlobalVisibility::Hidden);
            indices.set_initializer(&indices_value);
            Some(indices)
        };
        Self {
            indices,
            thread_counter,
            constants,
        }
    }
}

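/// Activity of each argument passed to Enzyme; these mirror Enzyme's
/// activity types (`Const` -> `CDIFFE_TYPE_DFT_CONSTANT`,
/// `Dup` -> `CDIFFE_TYPE_DFT_DUP_ARG`, `DupNoNeed` -> `CDIFFE_TYPE_DFT_DUP_NONEED`).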
pub enum CompileGradientArgType {
    Const,
    Dup,
    DupNoNeed,
}

pub enum CompileMode {
    Forward,
    ForwardSens,
    Reverse,
    ReverseSens,
}

pub struct CodeGen<'ctx> {
    context: &'ctx inkwell::context::Context,
    module: Module<'ctx>,
    builder: Builder<'ctx>,
    dibuilder: Option<DebugInfoBuilder<'ctx>>,
    compile_unit: Option<DICompileUnit<'ctx>>,
    variables: HashMap<String, PointerValue<'ctx>>,
    functions: HashMap<String, FunctionValue<'ctx>>,
    fn_value_opt: Option<FunctionValue<'ctx>>,
    tensor_ptr_opt: Option<PointerValue<'ctx>>,
    diffsl_real_type: RealType,
    real_type: FloatType<'ctx>,
    real_ptr_type: PointerType<'ctx>,
    int_type: IntType<'ctx>,
    int_ptr_type: PointerType<'ctx>,
    ptr_size_bits: u64,
    int_size_bits: u64,
    real_size_bits: u64,
    layout: DataLayout,
    globals: Globals<'ctx>,
    threaded: bool,
    _ee: Option<ExecutionEngine<'ctx>>,
}

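/// Enzyme custom forward-mode handler for `barrier`: returning 1 appears to
/// tell Enzyme to keep the primal call unchanged, since a barrier has no
/// derivative of its own.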
unsafe extern "C" fn fwd_handler(
    _builder: LLVMBuilderRef,
    _call_instruction: LLVMValueRef,
    _gutils: *mut GradientUtils,
    _dcall: *mut LLVMValueRef,
    _normal_return: *mut LLVMValueRef,
    _shadow_return: *mut LLVMValueRef,
) -> u8 {
    1
}

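/// Enzyme custom reverse-mode handler for `barrier`: emits a call to
/// `barrier_grad` with the same three arguments as the original call, so the
/// reverse pass synchronises at the mirrored barrier points.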
unsafe extern "C" fn rev_handler(
    builder: LLVMBuilderRef,
    call_instruction: LLVMValueRef,
    gutils: *mut DiffeGradientUtils,
    _tape: LLVMValueRef,
) {
    let call_block = LLVMGetInstructionParent(call_instruction);
    let call_function = LLVMGetBasicBlockParent(call_block);
    let module = LLVMGetGlobalParent(call_function);
    let name_c_str = CString::new("barrier_grad").unwrap();
    let barrier_func = LLVMGetNamedFunction(module, name_c_str.as_ptr());
    let barrier_func_type = LLVMGlobalGetValueType(barrier_func);
    let barrier_num = LLVMGetArgOperand(call_instruction, 0);
    let total_barriers = LLVMGetArgOperand(call_instruction, 1);
    let thread_count = LLVMGetArgOperand(call_instruction, 2);
    let barrier_num = EnzymeGradientUtilsNewFromOriginal(gutils as *mut GradientUtils, barrier_num);
    let total_barriers =
        EnzymeGradientUtilsNewFromOriginal(gutils as *mut GradientUtils, total_barriers);
    let thread_count =
        EnzymeGradientUtilsNewFromOriginal(gutils as *mut GradientUtils, thread_count);
    let mut args = [barrier_num, total_barriers, thread_count];
    let name_c_str = CString::new("").unwrap();
    LLVMBuildCall2(
        builder,
        barrier_func_type,
        barrier_func,
        args.as_mut_ptr(),
        args.len() as u32,
        name_c_str.as_ptr(),
    );
}


#[allow(dead_code)]
enum PrintValue<'ctx> {
    Real(FloatValue<'ctx>),
    Int(IntValue<'ctx>),
}

impl<'ctx> CodeGen<'ctx> {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        model: &DiscreteModel,
        context: &'ctx inkwell::context::Context,
        diffsl_real_type: RealType,
        real_type: FloatType<'ctx>,
        int_type: IntType<'ctx>,
        threaded: bool,
        ptr_size_bits: u64,
        real_size_bits: u64,
        int_size_bits: u64,
        debug: bool,
    ) -> Result<Self> {
        let builder = context.create_builder();
        let layout = DataLayout::new(model);
        let module = context.create_module(model.name());
        let (dibuilder, compile_unit) = if debug {
            let (dib, dic) = module.create_debug_info_builder(
                true,
                inkwell::debug_info::DWARFSourceLanguage::C,
                model.name(),
                ".",
                "diffsl compiler",
                true,
                "",
                0,
                "",
                inkwell::debug_info::DWARFEmissionKind::Full,
                0,
                false,
                false,
                "",
                "",
            );
            (Some(dib), Some(dic))
        } else {
            (None, None)
        };
        let globals = Globals::new(&layout, &module, int_type, real_type, threaded);
        let real_ptr_type = Self::pointer_type(context, real_type.into());
        let int_ptr_type = Self::pointer_type(context, int_type.into());
        let mut ret = Self {
            context,
            module,
            builder,
            dibuilder,
            compile_unit,
            real_type,
            real_ptr_type,
            variables: HashMap::new(),
            functions: HashMap::new(),
            fn_value_opt: None,
            tensor_ptr_opt: None,
            layout,
            diffsl_real_type,
            int_type,
            int_ptr_type,
            globals,
            threaded,
            ptr_size_bits,
            real_size_bits,
            int_size_bits,
            _ee: None,
        };
        if threaded {
            ret.compile_barrier_init()?;
            ret.compile_barrier()?;
            ret.compile_barrier_grad()?;
            // todo: think I can remove this unless I want to call enzyme using an llvm pass
            //ret.globals.add_registered_barrier(ret.context, &ret.module);
        }
        Ok(ret)
    }

    fn start_function(
        &mut self,
        function: FunctionValue<'ctx>,
        _code: Option<&str>,
    ) -> BasicBlock<'ctx> {
        let basic_block = self.context.append_basic_block(function, "entry");
        self.fn_value_opt = Some(function);
        self.builder.position_at_end(basic_block);

        if let Some(dibuilder) = &self.dibuilder {
            let scope = self
                .fn_value_opt
                .unwrap()
                .get_subprogram()
                .unwrap()
                .as_debug_info_scope();
            let loc = dibuilder.create_debug_location(self.context, 0, 0, scope, None);
            self.builder.set_current_debug_location(loc);
        }
        basic_block
    }

    fn add_function(
        &mut self,
        name: &str,
        arg_names: &[&str],
        arg_types: &[BasicMetadataTypeEnum<'ctx>],
        linkage: Option<Linkage>,
        is_real_return: bool,
    ) -> FunctionValue<'ctx> {
        let function_type = if is_real_return {
            self.real_type.fn_type(arg_types, false)
        } else {
            self.context.void_type().fn_type(arg_types, false)
        };
        let function = self.module.add_function(name, function_type, linkage);
        if let (Some(dibuilder), Some(compile_unit)) = (&self.dibuilder, &self.compile_unit) {
            let ditypes = arg_names
                .iter()
                .zip(arg_types.iter())
                .map(|(&name, &ty)| {
                    let size_in_bits = if ty.is_float_type() {
                        self.real_size_bits
                    } else if ty.is_int_type() {
                        self.int_size_bits
                    } else if ty.is_pointer_type() {
                        self.ptr_size_bits
                    } else {
                        unreachable!("Unsupported argument type for debug info")
                    };
                    dibuilder
                        .create_basic_type(
                            name,
                            size_in_bits,
                            0x00,
                            <DIFlags as DIFlagsConstants>::PUBLIC,
                        )
                        .unwrap()
                        .as_type()
                })
                .collect::<Vec<_>>();
            let subroutine_type = dibuilder.create_subroutine_type(
                compile_unit.get_file(),
                None,
                &ditypes,
                <DIFlags as DIFlagsConstants>::PUBLIC,
            );
            let func_scope = dibuilder.create_function(
                compile_unit.as_debug_info_scope(),
                name,
                None,
                compile_unit.get_file(),
                0,
                subroutine_type,
                true,
                true,
                0,
                <DIFlags as DIFlagsConstants>::PUBLIC,
                true,
            );
            function.set_subprogram(func_scope);
        }
        self.functions.insert(name.to_owned(), function);
        function
    }

    #[allow(dead_code)]
    fn compile_print_value(
        &mut self,
        name: &str,
        value: PrintValue<'ctx>,
    ) -> Result<CallSiteValue<'_>> {
        // get the printf function, or declare it if it doesn't exist
        let printf = match self.module.get_function("printf") {
            Some(f) => f,
            // int printf(const char *format, ...)
            None => self.add_function(
                "printf",
                &["format"],
                &[self.int_ptr_type.into()],
                Some(Linkage::External),
                false,
            ),
        };
        let (format_str, format_str_name) = match value {
            PrintValue::Real(_) => (format!("{name}: %f\n"), format!("real_format_{name}")),
            PrintValue::Int(_) => (format!("{name}: %d\n"), format!("int_format_{name}")),
        };
        // convert format_str to a C string
        let format_str = CString::new(format_str).unwrap();
        // if format_str_name does not already exist as a global, add it
        let format_str_global = match self.module.get_global(format_str_name.as_str()) {
            Some(g) => g,
            None => {
                let format_str = self.context.const_string(format_str.as_bytes(), true);
                let fmt_str =
                    self.module
                        .add_global(format_str.get_type(), None, format_str_name.as_str());
                fmt_str.set_initializer(&format_str);
                fmt_str.set_visibility(GlobalVisibility::Hidden);
                fmt_str
            }
        };
        // call printf with the format string and the value
        let format_str_ptr = self.builder.build_pointer_cast(
            format_str_global.as_pointer_value(),
            self.int_ptr_type,
            "format_str_ptr",
        )?;
        let value: BasicMetadataValueEnum = match value {
            PrintValue::Real(v) => v.into(),
            PrintValue::Int(v) => v.into(),
        };
        self.builder
            .build_call(printf, &[format_str_ptr.into(), value], "printf_call")
            .map_err(|e| anyhow!("Error building call to printf: {}", e))
    }

    fn compile_set_constants(
        &mut self,
        model: &DiscreteModel,
        code: Option<&str>,
    ) -> Result<FunctionValue<'ctx>> {
        self.clear();
        let fn_arg_names = &["thread_id", "thread_dim"];
        let function = self.add_function(
            "set_constants",
            fn_arg_names,
            &[self.int_type.into(), self.int_type.into()],
            None,
            false,
        );
        let _basic_block = self.start_function(function, code);

        for (i, arg) in function.get_param_iter().enumerate() {
            let name = fn_arg_names[i];
            let alloca = self.function_arg_alloca(name, arg);
            self.insert_param(name, alloca);
        }

        self.insert_indices();
        self.insert_constants(model);

        let mut nbarriers = 0;
        let total_barriers = (model.constant_defns().len()) as u64;
        #[allow(clippy::explicit_counter_loop)]
        for a in model.constant_defns() {
            self.jit_compile_tensor(a, Some(*self.get_var(a)), code)?;
            self.jit_compile_call_barrier(nbarriers, total_barriers);
            nbarriers += 1;
        }

        self.builder.build_return(None)?;

        if function.verify(true) {
            Ok(function)
        } else {
            function.print_to_stderr();
            self.module.print_to_stderr();
            unsafe {
                function.delete();
            }
            Err(anyhow!("Invalid generated function."))
        }
    }

    fn compile_barrier_init(&mut self) -> Result<FunctionValue<'ctx>> {
        self.clear();
        let function = self.add_function("barrier_init", &[], &[], None, false);
        let _entry_block = self.start_function(function, None);

        let thread_counter = self.globals.thread_counter.unwrap().as_pointer_value();
        self.builder
            .build_store(thread_counter, self.int_type.const_zero())?;

        self.builder.build_return(None)?;

        if function.verify(true) {
            Ok(function)
        } else {
            function.print_to_stderr();
            unsafe {
                function.delete();
            }
            Err(anyhow!("Invalid generated function."))
        }
    }

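    /// Emits a counter-based spin barrier: each arriving thread atomically
    /// increments the shared `thread_counter`, then busy-waits until the
    /// counter reaches `barrier_num * thread_count`, i.e. until every thread
    /// has passed barrier number `barrier_num`. When
    /// `barrier_num == total_barriers` the barrier is skipped entirely.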
    fn compile_barrier(&mut self) -> Result<FunctionValue<'ctx>> {
        self.clear();
        let function = self.add_function(
            "barrier",
            &["barrier_num", "total_barriers", "thread_count"],
            &[
                self.int_type.into(),
                self.int_type.into(),
                self.int_type.into(),
            ],
            None,
            false,
        );
        let noinline_kind_id = Attribute::get_named_enum_kind_id("noinline");
        let noinline = self.context.create_enum_attribute(noinline_kind_id, 0);
        function.add_attribute(AttributeLoc::Function, noinline);

        let _entry_block = self.start_function(function, None);
        let increment_block = self.context.append_basic_block(function, "increment");
        let wait_loop_block = self.context.append_basic_block(function, "wait_loop");
        let barrier_done_block = self.context.append_basic_block(function, "barrier_done");

        let thread_counter = self.globals.thread_counter.unwrap().as_pointer_value();
        let barrier_num = function.get_nth_param(0).unwrap().into_int_value();
        let total_barriers = function.get_nth_param(1).unwrap().into_int_value();
        let thread_count = function.get_nth_param(2).unwrap().into_int_value();

        let nbarrier_equals_total_barriers = self
            .builder
            .build_int_compare(
                IntPredicate::EQ,
                barrier_num,
                total_barriers,
                "nbarrier_equals_total_barriers",
            )
            .unwrap();
        // branch to barrier_done if nbarrier == total_barriers
        self.builder.build_conditional_branch(
            nbarrier_equals_total_barriers,
            barrier_done_block,
            increment_block,
        )?;
        self.builder.position_at_end(increment_block);

        let barrier_num_times_thread_count = self
            .builder
            .build_int_mul(barrier_num, thread_count, "barrier_num_times_thread_count")
            .unwrap();

        // Atomically increment the barrier counter
        let i32_type = self.context.i32_type();
        let one = i32_type.const_int(1, false);
        self.builder.build_atomicrmw(
            AtomicRMWBinOp::Add,
            thread_counter,
            one,
            AtomicOrdering::Monotonic,
        )?;

        // wait_loop:
        self.builder.build_unconditional_branch(wait_loop_block)?;
        self.builder.position_at_end(wait_loop_block);

        let current_value = self
            .builder
            .build_load(i32_type, thread_counter, "current_value")?
            .into_int_value();

        current_value
            .as_instruction_value()
            .unwrap()
            .set_atomic_ordering(AtomicOrdering::Monotonic)
            .map_err(|e| anyhow!("Error setting atomic ordering: {:?}", e))?;

        let all_threads_done = self.builder.build_int_compare(
            IntPredicate::UGE,
            current_value,
            barrier_num_times_thread_count,
            "all_threads_done",
        )?;

        self.builder.build_conditional_branch(
            all_threads_done,
            barrier_done_block,
            wait_loop_block,
        )?;
        self.builder.position_at_end(barrier_done_block);

        self.builder.build_return(None)?;

        if function.verify(true) {
            Ok(function)
        } else {
            function.print_to_stderr();
            unsafe {
                function.delete();
            }
            Err(anyhow!("Invalid generated function."))
        }
    }

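    /// Reverse-pass counterpart of `barrier`, called via Enzyme's custom
    /// handler. The reverse pass visits barriers in the opposite order, so the
    /// wait target keeps counting on past `total_barriers`:
    /// `(2 * total_barriers - barrier_num) * thread_count`.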
    fn compile_barrier_grad(&mut self) -> Result<FunctionValue<'ctx>> {
        self.clear();
        let function = self.add_function(
            "barrier_grad",
            &["barrier_num", "total_barriers", "thread_count"],
            &[
                self.int_type.into(),
                self.int_type.into(),
                self.int_type.into(),
            ],
            None,
            false,
        );

        let _entry_block = self.start_function(function, None);
        let wait_loop_block = self.context.append_basic_block(function, "wait_loop");
        let barrier_done_block = self.context.append_basic_block(function, "barrier_done");

        let thread_counter = self.globals.thread_counter.unwrap().as_pointer_value();
        let barrier_num = function.get_nth_param(0).unwrap().into_int_value();
        let total_barriers = function.get_nth_param(1).unwrap().into_int_value();
        let thread_count = function.get_nth_param(2).unwrap().into_int_value();

        let twice_total_barriers = self
            .builder
            .build_int_mul(
                total_barriers,
                self.int_type.const_int(2, false),
                "twice_total_barriers",
            )
            .unwrap();
        let twice_total_barriers_minus_barrier_num = self
            .builder
            .build_int_sub(
                twice_total_barriers,
                barrier_num,
                "twice_total_barriers_minus_barrier_num",
            )
            .unwrap();
        let twice_total_barriers_minus_barrier_num_times_thread_count = self
            .builder
            .build_int_mul(
                twice_total_barriers_minus_barrier_num,
                thread_count,
                "twice_total_barriers_minus_barrier_num_times_thread_count",
            )
            .unwrap();

        // Atomically increment the barrier counter
        let i32_type = self.context.i32_type();
        let one = i32_type.const_int(1, false);
        self.builder.build_atomicrmw(
            AtomicRMWBinOp::Add,
            thread_counter,
            one,
            AtomicOrdering::Monotonic,
        )?;

        // wait_loop:
        self.builder.build_unconditional_branch(wait_loop_block)?;
        self.builder.position_at_end(wait_loop_block);

        let current_value = self
            .builder
            .build_load(i32_type, thread_counter, "current_value")?
            .into_int_value();
        current_value
            .as_instruction_value()
            .unwrap()
            .set_atomic_ordering(AtomicOrdering::Monotonic)
            .map_err(|e| anyhow!("Error setting atomic ordering: {:?}", e))?;

        let all_threads_done = self.builder.build_int_compare(
            IntPredicate::UGE,
            current_value,
            twice_total_barriers_minus_barrier_num_times_thread_count,
            "all_threads_done",
        )?;

        self.builder.build_conditional_branch(
            all_threads_done,
            barrier_done_block,
            wait_loop_block,
        )?;
        self.builder.position_at_end(barrier_done_block);

        self.builder.build_return(None)?;

        if function.verify(true) {
            Ok(function)
        } else {
            function.print_to_stderr();
            unsafe {
                function.delete();
            }
            Err(anyhow!("Invalid generated function."))
        }
    }

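    /// Emits a call to `barrier` after tensor `nbarrier` of `total_barriers`
    /// has been computed; a no-op when the module is single-threaded.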
    fn jit_compile_call_barrier(&mut self, nbarrier: u64, total_barriers: u64) {
        if !self.threaded {
            return;
        }
        let thread_dim = self.get_param("thread_dim");
        let thread_dim = self
            .builder
            .build_load(self.int_type, *thread_dim, "thread_dim")
            .unwrap()
            .into_int_value();
        let nbarrier = self.int_type.const_int(nbarrier + 1, false);
        let total_barriers = self.int_type.const_int(total_barriers, false);
        let barrier = self.get_function("barrier").unwrap();
        self.builder
            .build_call(
                barrier,
                &[
                    BasicMetadataValueEnum::IntValue(nbarrier),
                    BasicMetadataValueEnum::IntValue(total_barriers),
                    BasicMetadataValueEnum::IntValue(thread_dim),
                ],
                "barrier",
            )
            .unwrap();
    }

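    /// Computes this thread's half-open slice `[start, end)` of a loop of
    /// `size` iterations, partitioned evenly across `thread_dim` threads:
    /// thread `i` covers `i * size / thread_dim` up to
    /// `(i + 1) * size / thread_dim`. Also opens a new basic block for the
    /// threaded body; `jit_end_threading` closes it.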
    fn jit_threading_limits(
        &mut self,
        size: IntValue<'ctx>,
    ) -> Result<(
        IntValue<'ctx>,
        IntValue<'ctx>,
        BasicBlock<'ctx>,
        BasicBlock<'ctx>,
    )> {
        let one = self.int_type.const_int(1, false);
        let thread_id = self.get_param("thread_id");
        let thread_id = self
            .builder
            .build_load(self.int_type, *thread_id, "thread_id")
            .unwrap()
            .into_int_value();
        let thread_dim = self.get_param("thread_dim");
        let thread_dim = self
            .builder
            .build_load(self.int_type, *thread_dim, "thread_dim")
            .unwrap()
            .into_int_value();

        // start index is i * size / thread_dim
        let i_times_size = self
            .builder
            .build_int_mul(thread_id, size, "i_times_size")?;
        let start = self
            .builder
            .build_int_unsigned_div(i_times_size, thread_dim, "start")?;

        // the ending index for thread i is (i+1) * size / thread_dim
        let i_plus_one = self.builder.build_int_add(thread_id, one, "i_plus_one")?;
        let i_plus_one_times_size =
            self.builder
                .build_int_mul(i_plus_one, size, "i_plus_one_times_size")?;
        let end = self
            .builder
            .build_int_unsigned_div(i_plus_one_times_size, thread_dim, "end")?;

        let test_done = self.builder.get_insert_block().unwrap();
        let next_block = self
            .context
            .append_basic_block(self.fn_value_opt.unwrap(), "threading_block");
        self.builder.position_at_end(next_block);

        Ok((start, end, test_done, next_block))
    }

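    /// Closes the threaded region opened by `jit_threading_limits`: appends an
    /// exit block, and wires the `test_done` block to skip the body entirely
    /// when this thread's slice is empty (`start == end`).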
1267    fn jit_end_threading(
1268        &mut self,
1269        start: IntValue<'ctx>,
1270        end: IntValue<'ctx>,
1271        test_done: BasicBlock<'ctx>,
1272        next: BasicBlock<'ctx>,
1273    ) -> Result<()> {
1274        let exit = self
1275            .context
1276            .append_basic_block(self.fn_value_opt.unwrap(), "exit");
1277        self.builder.build_unconditional_branch(exit)?;
1278        self.builder.position_at_end(test_done);
1279        // done if start == end
1280        let done = self
1281            .builder
1282            .build_int_compare(IntPredicate::EQ, start, end, "done")?;
1283        self.builder.build_conditional_branch(done, exit, next)?;
1284        self.builder.position_at_end(exit);
1285        Ok(())
1286    }
1287
1288    pub fn write_bitcode_to_path(&self, path: &std::path::Path) {
1289        self.module.write_bitcode_to_path(path);
1290    }
1291
1292    fn insert_constants(&mut self, model: &DiscreteModel) {
1293        if let Some(constants) = self.globals.constants.as_ref() {
1294            self.insert_param("constants", constants.as_pointer_value());
1295            for tensor in model.constant_defns() {
1296                self.insert_tensor(tensor, true);
1297            }
1298        }
1299    }
1300
1301    fn insert_data(&mut self, model: &DiscreteModel) {
1302        self.insert_constants(model);
1303
1304        if let Some(input) = model.input() {
1305            self.insert_tensor(input, false);
1306        }
1307        for tensor in model.input_dep_defns() {
1308            self.insert_tensor(tensor, false);
1309        }
1310        for tensor in model.time_dep_defns() {
1311            self.insert_tensor(tensor, false);
1312        }
1313        for tensor in model.state_dep_defns() {
1314            self.insert_tensor(tensor, false);
1315        }
1316    }
1317
1318    fn pointer_type(context: &'ctx Context, _ty: BasicTypeEnum<'ctx>) -> PointerType<'ctx> {
1319        context.ptr_type(AddressSpace::default())
1320    }
1321
1322    fn fn_pointer_type(context: &'ctx Context, _ty: FunctionType<'ctx>) -> PointerType<'ctx> {
1323        context.ptr_type(AddressSpace::default())
1324    }
1325
1326    fn insert_indices(&mut self) {
1327        if let Some(indices) = self.globals.indices.as_ref() {
1328            let i32_type = self.context.i32_type();
1329            let zero = i32_type.const_int(0, false);
1330            let ptr = unsafe {
1331                indices
1332                    .as_pointer_value()
1333                    .const_in_bounds_gep(i32_type, &[zero])
1334            };
1335            self.variables.insert("indices".to_owned(), ptr);
1336        }
1337    }
1338
1339    fn insert_param(&mut self, name: &str, value: PointerValue<'ctx>) {
1340        self.variables.insert(name.to_owned(), value);
1341    }
1342
1343    fn build_gep<T: BasicType<'ctx>>(
1344        &self,
1345        ty: T,
1346        ptr: PointerValue<'ctx>,
1347        ordered_indexes: &[IntValue<'ctx>],
1348        name: &str,
1349    ) -> Result<PointerValue<'ctx>> {
1350        unsafe {
1351            self.builder
1352                .build_gep(ty, ptr, ordered_indexes, name)
1353                .map_err(|e| e.into())
1354        }
1355    }
1356
1357    fn build_load<T: BasicType<'ctx>>(
1358        &self,
1359        ty: T,
1360        ptr: PointerValue<'ctx>,
1361        name: &str,
1362    ) -> Result<BasicValueEnum<'ctx>> {
1363        self.builder.build_load(ty, ptr, name).map_err(|e| e.into())
1364    }
1365
1366    fn get_ptr_to_index<T: BasicType<'ctx>>(
1367        builder: &Builder<'ctx>,
1368        ty: T,
1369        ptr: &PointerValue<'ctx>,
1370        index: IntValue<'ctx>,
1371        name: &str,
1372    ) -> PointerValue<'ctx> {
1373        unsafe {
1374            builder
1375                .build_in_bounds_gep(ty, *ptr, &[index], name)
1376                .unwrap()
1377        }
1378    }
1379
1380    fn insert_state(&mut self, u: &Tensor) {
1381        let mut data_index = 0;
1382        for blk in u.elmts() {
1383            if let Some(name) = blk.name() {
1384                let ptr = self.variables.get("u").unwrap();
1385                let i = self
1386                    .context
1387                    .i32_type()
1388                    .const_int(data_index.try_into().unwrap(), false);
1389                let alloca = Self::get_ptr_to_index(
1390                    &self.create_entry_block_builder(),
1391                    self.real_type,
1392                    ptr,
1393                    i,
1394                    blk.name().unwrap(),
1395                );
1396                self.variables.insert(name.to_owned(), alloca);
1397            }
1398            data_index += blk.nnz();
1399        }
1400    }
1401    fn insert_dot_state(&mut self, dudt: &Tensor) {
1402        let mut data_index = 0;
1403        for blk in dudt.elmts() {
1404            if let Some(name) = blk.name() {
1405                let ptr = self.variables.get("dudt").unwrap();
1406                let i = self
1407                    .context
1408                    .i32_type()
1409                    .const_int(data_index.try_into().unwrap(), false);
1410                let alloca = Self::get_ptr_to_index(
1411                    &self.create_entry_block_builder(),
1412                    self.real_type,
1413                    ptr,
1414                    i,
1415                    blk.name().unwrap(),
1416                );
1417                self.variables.insert(name.to_owned(), alloca);
1418            }
1419            data_index += blk.nnz();
1420        }
1421    }
1422    fn insert_tensor(&mut self, tensor: &Tensor, is_constant: bool) {
1423        let var_name = if is_constant { "constants" } else { "data" };
1424        let ptr = *self.variables.get(var_name).unwrap();
1425        let mut data_index = self.layout.get_data_index(tensor.name()).unwrap();
1426        let i = self
1427            .context
1428            .i32_type()
1429            .const_int(data_index.try_into().unwrap(), false);
1430        let alloca = Self::get_ptr_to_index(
1431            &self.create_entry_block_builder(),
1432            self.real_type,
1433            &ptr,
1434            i,
1435            tensor.name(),
1436        );
1437        self.variables.insert(tensor.name().to_owned(), alloca);
1438
        // insert any named blocks
1440        for blk in tensor.elmts() {
1441            if let Some(name) = blk.name() {
1442                let i = self
1443                    .context
1444                    .i32_type()
1445                    .const_int(data_index.try_into().unwrap(), false);
1446                let alloca = Self::get_ptr_to_index(
1447                    &self.create_entry_block_builder(),
1448                    self.real_type,
1449                    &ptr,
1450                    i,
1451                    name,
1452                );
1453                self.variables.insert(name.to_owned(), alloca);
1454            }
            // named blocks are only supported for rank <= 1, so we can just add the nnz to get the next data index
1456            data_index += blk.nnz();
1457        }
1458    }
1459    fn get_param(&self, name: &str) -> &PointerValue<'ctx> {
1460        self.variables.get(name).unwrap()
1461    }
1462
1463    fn get_var(&self, tensor: &Tensor) -> &PointerValue<'ctx> {
1464        self.variables.get(tensor.name()).unwrap()
1465    }
1466
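    /// Looks up (or lazily generates) the function `name`, caching the result.
    /// Math functions map to LLVM intrinsics where available, falling back to
    /// external libm declarations; the remainder (sigmoid, arcsinh/arccosh,
    /// heaviside, tanh/sinh/cosh) are generated on first use. The builder's
    /// position and debug location are restored before returning.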
1467    fn get_function(&mut self, name: &str) -> Option<FunctionValue<'ctx>> {
1468        // Check cache for function
1469        if let Some(&func) = self.functions.get(name) {
1470            return Some(func);
1471        }
        let current_block = self.builder.get_insert_block().unwrap();
        // a debug location is only present when debug info is enabled, so don't
        // unwrap it here; it is restored below only if it was set
        let current_loc = self.builder.get_current_debug_location();
        let current_function = self.fn_value_opt.unwrap();
1475
1476        let function = match name {
1477            // support some llvm intrinsics
1478            "sin" | "cos" | "tan" | "exp" | "log" | "log10" | "sqrt" | "abs" | "copysign"
1479            | "pow" | "min" | "max" => {
1480                let intrinsic_name = match name {
1481                    "min" => "minnum",
1482                    "max" => "maxnum",
1483                    "abs" => "fabs",
1484                    _ => name,
1485                };
1486                let llvm_name =
1487                    format!("llvm.{}.{}", intrinsic_name, self.diffsl_real_type.as_str());
1488
1489                let args_types: Vec<BasicTypeEnum> = vec![self.real_type.into()];
1490                self.builder.unset_current_debug_location();
1491
                // Try the LLVM intrinsic first; fall back to an external libm declaration
1493                Some(match Intrinsic::find(&llvm_name) {
1494                    Some(intrinsic) => intrinsic.get_declaration(&self.module, &args_types)?,
1495                    None => {
1496                        // Fallback: declare external libm function
1497                        let args_types_meta: Vec<BasicMetadataTypeEnum> =
1498                            vec![self.real_type.into()];
1499                        let function_type = self.real_type.fn_type(&args_types_meta, false);
1500                        self.module
1501                            .add_function(name, function_type, Some(Linkage::External))
1502                    }
1503                })
1504            }
1505            // some custom functions
1506            "sigmoid" => {
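                // generates sigmoid(x) = 1 / (1 + exp(-x))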
1507                let arg_len = 1;
1508                let ret_type = self.real_type;
1509
1510                let args_types = std::iter::repeat_n(ret_type, arg_len)
1511                    .map(|f| f.into())
1512                    .collect::<Vec<BasicMetadataTypeEnum>>();
1513
1514                let fn_val = self.add_function(name, &["x"], &args_types, None, true);
1515
1516                for arg in fn_val.get_param_iter() {
1517                    arg.into_float_value().set_name("x");
1518                }
1519                let _basic_block = self.start_function(fn_val, None);
1520
1521                let x = fn_val.get_nth_param(0)?.into_float_value();
1522                let one = self.real_type.const_float(1.0);
1523                let negx = self.builder.build_float_neg(x, name).ok()?;
1524                let exp = self.get_function("exp").unwrap();
1525                let exp_negx = self
1526                    .builder
1527                    .build_call(exp, &[BasicMetadataValueEnum::FloatValue(negx)], name)
1528                    .ok()?;
1529                let one_plus_exp_negx = self
1530                    .builder
1531                    .build_float_add(
1532                        exp_negx
1533                            .try_as_basic_value()
1534                            .unwrap_basic()
1535                            .into_float_value(),
1536                        one,
1537                        name,
1538                    )
1539                    .ok()?;
1540                let sigmoid = self
1541                    .builder
1542                    .build_float_div(one, one_plus_exp_negx, name)
1543                    .ok()?;
1544                self.builder.build_return(Some(&sigmoid)).ok();
1545                Some(fn_val)
1546            }
1547            "arcsinh" | "arccosh" => {
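                // generates arcsinh(x) = ln(x + sqrt(x^2 + 1)) and
                // arccosh(x) = ln(x + sqrt(x^2 - 1))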
1548                let arg_len = 1;
1549                let ret_type = self.real_type;
1550
1551                let args_types = std::iter::repeat_n(ret_type, arg_len)
1552                    .map(|f| f.into())
1553                    .collect::<Vec<BasicMetadataTypeEnum>>();
1554
1555                let fn_val = self.add_function(name, &["x"], &args_types, None, true);
1556
1557                for arg in fn_val.get_param_iter() {
1558                    arg.into_float_value().set_name("x");
1559                }
1560
1561                let _basic_block = self.start_function(fn_val, None);
1562                let x = fn_val.get_nth_param(0)?.into_float_value();
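                // despite its name, `one` is -1.0 for arccosh, matching the identities above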
1563                let one = match name {
1564                    "arccosh" => self.real_type.const_float(-1.0),
1565                    "arcsinh" => self.real_type.const_float(1.0),
1566                    _ => panic!("unknown function"),
1567                };
1568                let x_squared = self.builder.build_float_mul(x, x, name).ok()?;
1569                let one_plus_x_squared = self.builder.build_float_add(x_squared, one, name).ok()?;
1570                let sqrt = self.get_function("sqrt").unwrap();
1571                let sqrt_one_plus_x_squared = self
1572                    .builder
1573                    .build_call(
1574                        sqrt,
1575                        &[BasicMetadataValueEnum::FloatValue(one_plus_x_squared)],
1576                        name,
1577                    )
1578                    .unwrap()
1579                    .try_as_basic_value()
1580                    .unwrap_basic()
1581                    .into_float_value();
1582                let x_plus_sqrt_one_plus_x_squared = self
1583                    .builder
1584                    .build_float_add(x, sqrt_one_plus_x_squared, name)
1585                    .ok()?;
1586                let ln = self.get_function("log").unwrap();
1587                let result = self
1588                    .builder
1589                    .build_call(
1590                        ln,
1591                        &[BasicMetadataValueEnum::FloatValue(
1592                            x_plus_sqrt_one_plus_x_squared,
1593                        )],
1594                        name,
1595                    )
1596                    .unwrap()
1597                    .try_as_basic_value()
1598                    .unwrap_basic()
1599                    .into_float_value();
1600                self.builder.build_return(Some(&result)).ok();
1601                Some(fn_val)
1602            }
1603            "heaviside" => {
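                // generates heaviside(x) = 1.0 if x >= 0, else 0.0 (the H(0) = 1 convention)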
1604                let arg_len = 1;
1605                let ret_type = self.real_type;
1606
1607                let args_types = std::iter::repeat_n(ret_type, arg_len)
1608                    .map(|f| f.into())
1609                    .collect::<Vec<BasicMetadataTypeEnum>>();
1610
1611                let fn_val = self.add_function(name, &["x"], &args_types, None, true);
1612
1613                for arg in fn_val.get_param_iter() {
1614                    arg.into_float_value().set_name("x");
1615                }
1616
1617                let _basic_block = self.start_function(fn_val, None);
1618                let x = fn_val.get_nth_param(0)?.into_float_value();
1619                let zero = self.real_type.const_float(0.0);
1620                let one = self.real_type.const_float(1.0);
1621                let result = self
1622                    .builder
1623                    .build_select(
1624                        self.builder
1625                            .build_float_compare(FloatPredicate::OGE, x, zero, "x >= 0")
1626                            .unwrap(),
1627                        one,
1628                        zero,
1629                        name,
1630                    )
1631                    .ok()?;
1632                self.builder.build_return(Some(&result)).ok();
1633                Some(fn_val)
1634            }
1635            "tanh" | "sinh" | "cosh" => {
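                // generates these from exp:
                //   tanh(x) = (e^x - e^-x) / (e^x + e^-x)
                //   sinh(x) = (e^x - e^-x) / 2
                //   cosh(x) = (e^x + e^-x) / 2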
1636                let arg_len = 1;
1637                let ret_type = self.real_type;
1638
1639                let args_types = std::iter::repeat_n(ret_type, arg_len)
1640                    .map(|f| f.into())
1641                    .collect::<Vec<BasicMetadataTypeEnum>>();
1642
1643                let fn_val = self.add_function(name, &["x"], &args_types, None, true);
1644
1645                for arg in fn_val.get_param_iter() {
1646                    arg.into_float_value().set_name("x");
1647                }
1648
1649                let _basic_block = self.start_function(fn_val, None);
1650                let x = fn_val.get_nth_param(0)?.into_float_value();
1651                let negx = self.builder.build_float_neg(x, name).ok()?;
1652                let exp = self.get_function("exp").unwrap();
1653                let exp_negx = self
1654                    .builder
1655                    .build_call(exp, &[BasicMetadataValueEnum::FloatValue(negx)], name)
1656                    .ok()?;
1657                let expx = self
1658                    .builder
1659                    .build_call(exp, &[BasicMetadataValueEnum::FloatValue(x)], name)
1660                    .ok()?;
1661                let expx_minus_exp_negx = self
1662                    .builder
1663                    .build_float_sub(
1664                        expx.try_as_basic_value().unwrap_basic().into_float_value(),
1665                        exp_negx
1666                            .try_as_basic_value()
1667                            .unwrap_basic()
1668                            .into_float_value(),
1669                        name,
1670                    )
1671                    .ok()?;
1672                let expx_plus_exp_negx = self
1673                    .builder
1674                    .build_float_add(
1675                        expx.try_as_basic_value().unwrap_basic().into_float_value(),
1676                        exp_negx
1677                            .try_as_basic_value()
1678                            .unwrap_basic()
1679                            .into_float_value(),
1680                        name,
1681                    )
1682                    .ok()?;
1683                let result = match name {
1684                    "tanh" => self
1685                        .builder
1686                        .build_float_div(expx_minus_exp_negx, expx_plus_exp_negx, name)
1687                        .ok()?,
1688                    "sinh" => self
1689                        .builder
1690                        .build_float_div(expx_minus_exp_negx, self.real_type.const_float(2.0), name)
1691                        .ok()?,
1692                    "cosh" => self
1693                        .builder
1694                        .build_float_div(expx_plus_exp_negx, self.real_type.const_float(2.0), name)
1695                        .ok()?,
1696                    _ => panic!("unknown function"),
1697                };
1698                self.builder.build_return(Some(&result)).ok();
1699                Some(fn_val)
1700            }
1701            _ => None,
1702        }?;
1703
1704        if !function.verify(true) {
1705            function.print_to_stderr();
1706            panic!("Invalid generated function for {}", name);
1707        }
1708
        self.builder.position_at_end(current_block);
        if let Some(current_loc) = current_loc {
            self.builder.set_current_debug_location(current_loc);
        }
1713        self.fn_value_opt = Some(current_function);
1714
1715        self.functions.insert(name.to_owned(), function);
1716        Some(function)
1717    }
1718
1719    /// Returns the `FunctionValue` representing the function being compiled.
1720    #[inline]
1721    fn fn_value(&self) -> FunctionValue<'ctx> {
1722        self.fn_value_opt.unwrap()
1723    }
1724
1725    #[inline]
1726    fn tensor_ptr(&self) -> PointerValue<'ctx> {
1727        self.tensor_ptr_opt.unwrap()
1728    }
1729
1730    /// Creates a new builder in the entry block of the function.
1731    fn create_entry_block_builder(&self) -> Builder<'ctx> {
1732        let builder = self.context.create_builder();
1733        let entry = self.fn_value().get_first_basic_block().unwrap();
1734        match entry.get_first_instruction() {
1735            Some(first_instr) => builder.position_before(&first_instr),
1736            None => builder.position_at_end(entry),
1737        }
1738        builder
1739    }
1740
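    /// Compiles a rank-0 (scalar) tensor. When threading is enabled, only
    /// thread 0 evaluates the expression; the other threads branch straight
    /// to the exit block.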
1741    fn jit_compile_scalar(
1742        &mut self,
1743        a: &Tensor,
1744        res_ptr_opt: Option<PointerValue<'ctx>>,
1745    ) -> Result<PointerValue<'ctx>> {
1746        let res_type = self.real_type;
1747        let res_ptr = match res_ptr_opt {
1748            Some(ptr) => ptr,
1749            None => self
1750                .create_entry_block_builder()
1751                .build_alloca(res_type, a.name())?,
1752        };
1753        let name = a.name();
1754        let elmt = a.elmts().first().unwrap();
1755
1756        // if threaded then only the first thread will evaluate the scalar
1757        let curr_block = self.builder.get_insert_block().unwrap();
1758        let mut next_block_opt = None;
1759        if self.threaded {
1760            let next_block = self.context.append_basic_block(self.fn_value(), "next");
1761            self.builder.position_at_end(next_block);
1762            next_block_opt = Some(next_block);
1763        }
1764
1765        let zero = self.int_type.const_zero();
1766        let float_value = self.jit_compile_expr(name, elmt.expr(), &[], elmt, zero)?;
1767        self.builder.build_store(res_ptr, float_value)?;
1768
1769        // complete the threading block
1770        if self.threaded {
1771            let exit_block = self.context.append_basic_block(self.fn_value(), "exit");
1772            self.builder.build_unconditional_branch(exit_block)?;
1773            self.builder.position_at_end(curr_block);
1774
1775            let thread_id = self.get_param("thread_id");
1776            let thread_id = self
1777                .builder
1778                .build_load(self.int_type, *thread_id, "thread_id")
1779                .unwrap()
1780                .into_int_value();
1781            let is_first_thread = self.builder.build_int_compare(
1782                IntPredicate::EQ,
1783                thread_id,
1784                self.int_type.const_zero(),
1785                "is_first_thread",
1786            )?;
1787            self.builder.build_conditional_branch(
1788                is_first_thread,
1789                next_block_opt.unwrap(),
1790                exit_block,
1791            )?;
1792
1793            self.builder.position_at_end(exit_block);
1794        }
1795
1796        Ok(res_ptr)
1797    }
1798
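    /// Compiles a tensor into `res_ptr_opt` (or a fresh entry-block alloca),
    /// emitting code for each of the tensor's blocks in turn.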
1799    fn jit_compile_tensor(
1800        &mut self,
1801        a: &Tensor,
1802        res_ptr_opt: Option<PointerValue<'ctx>>,
1803        code: Option<&str>,
1804    ) -> Result<PointerValue<'ctx>> {
1805        // treat scalar as a special case
1806        if a.rank() == 0 {
1807            return self.jit_compile_scalar(a, res_ptr_opt);
1808        }
1809
1810        let res_type = self.real_type;
1811        let res_ptr = match res_ptr_opt {
1812            Some(ptr) => ptr,
1813            None => self
1814                .create_entry_block_builder()
1815                .build_alloca(res_type, a.name())?,
1816        };
1817
1818        // set up the tensor storage pointer and index into this data
1819        self.tensor_ptr_opt = Some(res_ptr);
1820
1821        for (i, blk) in a.elmts().iter().enumerate() {
1822            let default = format!("{}-{}", a.name(), i);
1823            let name = blk.name().unwrap_or(default.as_str());
1824            self.jit_compile_block(name, a, blk, code)?;
1825        }
1826        Ok(res_ptr)
1827    }
1828
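    /// Compiles a single tensor block, dispatching on the expression layout
    /// (dense, diagonal, or sparse) and attaching a debug location if debug
    /// info is enabled.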
1829    fn jit_compile_block(
1830        &mut self,
1831        name: &str,
1832        tensor: &Tensor,
1833        elmt: &TensorBlock,
1834        code: Option<&str>,
1835    ) -> Result<()> {
1836        let translation = Translation::new(
1837            elmt.expr_layout(),
1838            elmt.layout(),
1839            elmt.start(),
1840            tensor.layout_ptr(),
1841        );
1842
1843        if let Some(dibuilder) = &self.dibuilder {
1844            let (line_no, column_no) = match (elmt.expr().span, code) {
1845                (Some(s), Some(c)) => {
1846                    let s = Span::new(c, s.pos_start, s.pos_end).unwrap();
1847                    s.start_pos().line_col()
1848                }
1849                _ => (0, 0),
1850            };
1851            let scope = self
1852                .fn_value_opt
1853                .unwrap()
1854                .get_subprogram()
1855                .unwrap()
1856                .as_debug_info_scope();
1857            let loc = dibuilder.create_debug_location(
1858                self.context,
1859                line_no.try_into().unwrap(),
1860                column_no.try_into().unwrap(),
1861                scope,
1862                None,
1863            );
1864            self.builder.set_current_debug_location(loc);
1865        }
1866
1867        if elmt.expr_layout().is_dense() {
1868            self.jit_compile_dense_block(name, elmt, &translation)
1869        } else if elmt.expr_layout().is_diagonal() {
1870            self.jit_compile_diagonal_block(name, elmt, &translation)
1871        } else if elmt.expr_layout().is_sparse() {
1872            match translation.source {
1873                TranslationFrom::SparseContraction { .. } => {
1874                    self.jit_compile_sparse_contraction_block(name, elmt, &translation)
1875                }
1876                _ => self.jit_compile_sparse_block(name, elmt, &translation),
1877            }
1878        } else {
1879            Err(anyhow!(
1880                "unsupported block layout: {:?}",
1881                elmt.expr_layout()
1882            ))
1883        }
1884    }
1885
    // for dense blocks we build a loop nest, one loop per dimension, to calculate the element index, then compile the expression passing in this index
1887    fn jit_compile_dense_block(
1888        &mut self,
1889        name: &str,
1890        elmt: &TensorBlock,
1891        translation: &Translation,
1892    ) -> Result<()> {
1893        let int_type = self.int_type;
1894
1895        let mut preblock = self.builder.get_insert_block().unwrap();
1896        let expr_rank = elmt.expr_layout().rank();
1897        let expr_shape = elmt
1898            .expr_layout()
1899            .shape()
1900            .mapv(|n| int_type.const_int(n.try_into().unwrap(), false));
1901        let one = int_type.const_int(1, false);
1902
1903        let mut expr_strides = vec![1; expr_rank];
1904        if expr_rank > 0 {
1905            for i in (0..expr_rank - 1).rev() {
1906                expr_strides[i] = expr_strides[i + 1] * elmt.expr_layout().shape()[i + 1];
1907            }
1908        }
1909        let expr_strides = expr_strides
1910            .iter()
1911            .map(|&s| int_type.const_int(s.try_into().unwrap(), false))
1912            .collect::<Vec<IntValue>>();
1913
        // storage for the loop-nest index phis and their basic blocks
1915        let mut indices = Vec::new();
1916        let mut blocks = Vec::new();
1917
1918        // allocate the contract sum if needed
1919        let (contract_sum, contract_by, contract_strides) =
1920            if let TranslationFrom::DenseContraction {
1921                contract_by,
1922                contract_len: _,
1923            } = translation.source
1924            {
1925                let contract_rank = expr_rank - contract_by;
1926                let mut contract_strides = vec![1; contract_rank];
1927                for i in (0..contract_rank - 1).rev() {
1928                    contract_strides[i] =
1929                        contract_strides[i + 1] * elmt.expr_layout().shape()[i + 1];
1930                }
1931                let contract_strides = contract_strides
1932                    .iter()
1933                    .map(|&s| int_type.const_int(s.try_into().unwrap(), false))
1934                    .collect::<Vec<IntValue>>();
1935                (
1936                    Some(self.builder.build_alloca(self.real_type, "contract_sum")?),
1937                    contract_by,
1938                    Some(contract_strides),
1939                )
1940            } else {
1941                (None, 0, None)
1942            };
1943
1944        // we will thread the output loop, except if we are contracting to a scalar
1945        let (thread_start, thread_end, test_done, next) = if self.threaded {
1946            let (start, end, test_done, next) =
1947                self.jit_threading_limits(*expr_shape.get(0).unwrap_or(&one))?;
1948            preblock = next;
1949            (Some(start), Some(end), Some(test_done), Some(next))
1950        } else {
1951            (None, None, None, None)
1952        };
1953
1954        for i in 0..expr_rank {
1955            let block = self.context.append_basic_block(self.fn_value(), name);
1956            self.builder.build_unconditional_branch(block)?;
1957            self.builder.position_at_end(block);
1958
1959            let start_index = if i == 0 && self.threaded {
1960                thread_start.unwrap()
1961            } else {
1962                self.int_type.const_zero()
1963            };
1964
            let curr_index = self.builder.build_phi(int_type, format!("i{i}").as_str())?;
1966            curr_index.add_incoming(&[(&start_index, preblock)]);
1967
1968            if i == expr_rank - contract_by - 1 {
1969                if let Some(contract_sum) = contract_sum {
1970                    self.builder
1971                        .build_store(contract_sum, self.real_type.const_zero())?;
1972                }
1973            }
1974
1975            indices.push(curr_index);
1976            blocks.push(block);
1977            preblock = block;
1978        }
1979
1980        let indices_int: Vec<IntValue> = indices
1981            .iter()
1982            .map(|i| i.as_basic_value().into_int_value())
1983            .collect();
1984
        // if indices = (i, j, k) and shape = (a, b, c) calculate expr_index = (k + j*c + i*b*c)
1986        let mut expr_index = *indices_int.last().unwrap_or(&int_type.const_zero());
1987        let mut stride = 1u64;
1988        if !indices.is_empty() {
1989            for i in (0..indices.len() - 1).rev() {
1990                let iname_i = indices_int[i];
1991                let shapei: u64 = elmt.expr_layout().shape()[i + 1].try_into().unwrap();
1992                stride *= shapei;
1993                let stride_intval = self.context.i32_type().const_int(stride, false);
1994                let stride_mul_i = self.builder.build_int_mul(stride_intval, iname_i, name)?;
1995                expr_index = self.builder.build_int_add(expr_index, stride_mul_i, name)?;
1996            }
1997        }
1998
1999        let float_value =
2000            self.jit_compile_expr(name, elmt.expr(), indices_int.as_slice(), elmt, expr_index)?;
2001
2002        if let Some(contract_sum) = contract_sum {
2003            let contract_sum_value = self
2004                .build_load(self.real_type, contract_sum, "contract_sum")?
2005                .into_float_value();
2006            let new_contract_sum_value = self.builder.build_float_add(
2007                contract_sum_value,
2008                float_value,
2009                "new_contract_sum",
2010            )?;
2011            self.builder
2012                .build_store(contract_sum, new_contract_sum_value)?;
2013        } else {
2014            let expr_index = indices_int.iter().zip(expr_strides.iter()).fold(
2015                self.int_type.const_zero(),
2016                |acc, (i, s)| {
2017                    let tmp = self.builder.build_int_mul(*i, *s, "expr_index").unwrap();
2018                    self.builder.build_int_add(acc, tmp, "acc").unwrap()
2019                },
2020            );
2021            self.jit_compile_broadcast_and_store(name, elmt, float_value, expr_index, translation)?;
2022        }
2023
2024        let mut postblock = self.builder.get_insert_block().unwrap();
2025
2026        // unwind the nested loops
2027        for i in (0..expr_rank).rev() {
2028            // increment index
2029            let next_index = self.builder.build_int_add(indices_int[i], one, name)?;
2030            indices[i].add_incoming(&[(&next_index, postblock)]);
2031            if i == expr_rank - contract_by - 1 {
2032                if let Some(contract_sum) = contract_sum {
2033                    let contract_sum_value = self
2034                        .build_load(self.real_type, contract_sum, "contract_sum")?
2035                        .into_float_value();
2036                    let contract_strides = contract_strides.as_ref().unwrap();
2037                    let elmt_index = indices_int
2038                        .iter()
2039                        .take(contract_strides.len())
2040                        .zip(contract_strides.iter())
2041                        .fold(self.int_type.const_zero(), |acc, (i, s)| {
2042                            let tmp = self.builder.build_int_mul(*i, *s, "elmt_index").unwrap();
2043                            self.builder.build_int_add(acc, tmp, "acc").unwrap()
2044                        });
2045                    self.jit_compile_store(
2046                        name,
2047                        elmt,
2048                        elmt_index,
2049                        contract_sum_value,
2050                        translation,
2051                    )?;
2052                }
2053            }
2054
2055            let end_index = if i == 0 && self.threaded {
2056                thread_end.unwrap()
2057            } else {
2058                expr_shape[i]
2059            };
2060
2061            // loop condition
2062            let loop_while =
2063                self.builder
2064                    .build_int_compare(IntPredicate::ULT, next_index, end_index, name)?;
2065            let block = self.context.append_basic_block(self.fn_value(), name);
2066            self.builder
2067                .build_conditional_branch(loop_while, blocks[i], block)?;
2068            self.builder.position_at_end(block);
2069            postblock = block;
2070        }
2071
2072        if self.threaded {
2073            self.jit_end_threading(
2074                thread_start.unwrap(),
2075                thread_end.unwrap(),
2076                test_done.unwrap(),
2077                next.unwrap(),
2078            )?;
2079        }
2080        Ok(())
2081    }
2082
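    /// Compiles a sparse contraction block: for each output element, the
    /// start/end of its contraction range is loaded from the `indices` array
    /// and the expression values over that range are summed.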
2083    fn jit_compile_sparse_contraction_block(
2084        &mut self,
2085        name: &str,
2086        elmt: &TensorBlock,
2087        translation: &Translation,
2088    ) -> Result<()> {
2089        match translation.source {
2090            TranslationFrom::SparseContraction { .. } => {}
2091            _ => {
2092                panic!("expected sparse contraction")
2093            }
2094        }
2095        let int_type = self.int_type;
2096
2097        let translation_index = self
2098            .layout
2099            .get_translation_index(elmt.expr_layout(), elmt.layout())
2100            .unwrap();
2101        let translation_index = translation_index + translation.get_from_index_in_data_layout();
2102
2103        let final_contract_index =
2104            int_type.const_int(elmt.layout().nnz().try_into().unwrap(), false);
2105        let (thread_start, thread_end, test_done, next) = if self.threaded {
2106            let (start, end, test_done, next) = self.jit_threading_limits(final_contract_index)?;
2107            (Some(start), Some(end), Some(test_done), Some(next))
2108        } else {
2109            (None, None, None, None)
2110        };
2111
2112        let preblock = self.builder.get_insert_block().unwrap();
2113        let contract_sum_ptr = self.builder.build_alloca(self.real_type, "contract_sum")?;
2114
2115        // loop through each contraction
2116        let block = self.context.append_basic_block(self.fn_value(), name);
2117        self.builder.build_unconditional_branch(block)?;
2118        self.builder.position_at_end(block);
2119
2120        let contract_index = self.builder.build_phi(int_type, "i")?;
2121        let contract_start = if self.threaded {
2122            thread_start.unwrap()
2123        } else {
2124            int_type.const_zero()
2125        };
2126        contract_index.add_incoming(&[(&contract_start, preblock)]);
2127
2128        let start_index = self.builder.build_int_add(
2129            int_type.const_int(translation_index.try_into().unwrap(), false),
2130            self.builder.build_int_mul(
2131                int_type.const_int(2, false),
2132                contract_index.as_basic_value().into_int_value(),
2133                name,
2134            )?,
2135            name,
2136        )?;
2137        let end_index =
2138            self.builder
2139                .build_int_add(start_index, int_type.const_int(1, false), name)?;
2140        let start_ptr = self.build_gep(
2141            self.int_type,
2142            *self.get_param("indices"),
2143            &[start_index],
2144            "start_index_ptr",
2145        )?;
2146        let start_contract = self
2147            .build_load(self.int_type, start_ptr, "start")?
2148            .into_int_value();
2149        let end_ptr = self.build_gep(
2150            self.int_type,
2151            *self.get_param("indices"),
2152            &[end_index],
2153            "end_index_ptr",
2154        )?;
2155        let end_contract = self
2156            .build_load(self.int_type, end_ptr, "end")?
2157            .into_int_value();
2158
2159        // initialise the contract sum
2160        self.builder
2161            .build_store(contract_sum_ptr, self.real_type.const_float(0.0))?;
2162
2163        // loop through each element in the contraction
2164        let start_contract_block = self
2165            .context
2166            .append_basic_block(self.fn_value(), format!("{name}_contract").as_str());
2167        self.builder
2168            .build_unconditional_branch(start_contract_block)?;
2169        self.builder.position_at_end(start_contract_block);
2170
2171        let expr_index_phi = self.builder.build_phi(int_type, "j")?;
2172        expr_index_phi.add_incoming(&[(&start_contract, block)]);
2173
2174        let expr_index = expr_index_phi.as_basic_value().into_int_value();
2175        let indices_int = self.expr_indices_from_elmt_index(expr_index, elmt, name)?;
2176
2177        // loop body - eval expression and increment sum
2178        let float_value =
2179            self.jit_compile_expr(name, elmt.expr(), indices_int.as_slice(), elmt, expr_index)?;
2180        let contract_sum_value = self
2181            .build_load(self.real_type, contract_sum_ptr, "contract_sum")?
2182            .into_float_value();
2183        let new_contract_sum_value =
2184            self.builder
2185                .build_float_add(contract_sum_value, float_value, "new_contract_sum")?;
2186        self.builder
2187            .build_store(contract_sum_ptr, new_contract_sum_value)?;
2188
2189        let end_contract_block = self.builder.get_insert_block().unwrap();
2190
2191        // increment contract loop index
2192        let next_elmt_index =
2193            self.builder
2194                .build_int_add(expr_index, int_type.const_int(1, false), name)?;
2195        expr_index_phi.add_incoming(&[(&next_elmt_index, end_contract_block)]);
2196
2197        // contract loop condition
2198        let loop_while = self.builder.build_int_compare(
2199            IntPredicate::ULT,
2200            next_elmt_index,
2201            end_contract,
2202            name,
2203        )?;
2204        let post_contract_block = self.context.append_basic_block(self.fn_value(), name);
2205        self.builder.build_conditional_branch(
2206            loop_while,
2207            start_contract_block,
2208            post_contract_block,
2209        )?;
2210        self.builder.position_at_end(post_contract_block);
2211
2212        // store the result
2213        self.jit_compile_store(
2214            name,
2215            elmt,
2216            contract_index.as_basic_value().into_int_value(),
2217            new_contract_sum_value,
2218            translation,
2219        )?;
2220
2221        // increment outer loop index
2222        let next_contract_index = self.builder.build_int_add(
2223            contract_index.as_basic_value().into_int_value(),
2224            int_type.const_int(1, false),
2225            name,
2226        )?;
2227        contract_index.add_incoming(&[(&next_contract_index, post_contract_block)]);
2228
2229        // outer loop condition
2230        let loop_while = self.builder.build_int_compare(
2231            IntPredicate::ULT,
2232            next_contract_index,
2233            thread_end.unwrap_or(final_contract_index),
2234            name,
2235        )?;
2236        let post_block = self.context.append_basic_block(self.fn_value(), name);
2237        self.builder
2238            .build_conditional_branch(loop_while, block, post_block)?;
2239        self.builder.position_at_end(post_block);
2240
2241        if self.threaded {
2242            self.jit_end_threading(
2243                thread_start.unwrap(),
2244                thread_end.unwrap(),
2245                test_done.unwrap(),
2246                next.unwrap(),
2247            )?;
2248        }
2249
2250        Ok(())
2251    }
2252
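    /// Recovers the rank-many expression indices of a sparse element by
    /// loading them from the layout's section of the `indices` array.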
2253    fn expr_indices_from_elmt_index(
2254        &mut self,
2255        elmt_index: IntValue<'ctx>,
2256        elmt: &TensorBlock,
2257        name: &str,
2258    ) -> Result<Vec<IntValue<'ctx>>, anyhow::Error> {
2259        let layout_index = self.layout.get_layout_index(elmt.expr_layout()).unwrap();
2260        let int_type = self.int_type;
2261        // loop body - load index from layout
2262        let elmt_index_mult_rank = self.builder.build_int_mul(
2263            elmt_index,
2264            int_type.const_int(elmt.expr_layout().rank().try_into().unwrap(), false),
2265            name,
2266        )?;
2267        (0..elmt.expr_layout().rank())
2268            .map(|i| {
2269                let layout_index_plus_offset =
2270                    int_type.const_int((layout_index + i).try_into().unwrap(), false);
2271                let curr_index = self.builder.build_int_add(
2272                    elmt_index_mult_rank,
2273                    layout_index_plus_offset,
2274                    name,
2275                )?;
2276                let ptr = Self::get_ptr_to_index(
2277                    &self.builder,
2278                    self.int_type,
2279                    self.get_param("indices"),
2280                    curr_index,
2281                    name,
2282                );
2283                Ok(self.build_load(self.int_type, ptr, name)?.into_int_value())
2284            })
2285            .collect::<Result<Vec<_>, anyhow::Error>>()
2286    }
2287
    // for sparse blocks we loop through the non-zero elements, extract each element's index from the layout, then compile the expression passing in this index
    // TODO: haven't implemented contractions yet
2290    fn jit_compile_sparse_block(
2291        &mut self,
2292        name: &str,
2293        elmt: &TensorBlock,
2294        translation: &Translation,
2295    ) -> Result<()> {
2296        let int_type = self.int_type;
2297
2298        let start_index = int_type.const_int(0, false);
2299        let end_index = int_type.const_int(elmt.expr_layout().nnz().try_into().unwrap(), false);
2300
2301        let (thread_start, thread_end, test_done, next) = if self.threaded {
2302            let (start, end, test_done, next) = self.jit_threading_limits(end_index)?;
2303            (Some(start), Some(end), Some(test_done), Some(next))
2304        } else {
2305            (None, None, None, None)
2306        };
2307
2308        // loop through the non-zero elements
2309        let preblock = self.builder.get_insert_block().unwrap();
2310        let loop_block = self.context.append_basic_block(self.fn_value(), name);
2311        self.builder.build_unconditional_branch(loop_block)?;
2312        self.builder.position_at_end(loop_block);
2313
2314        let curr_index = self.builder.build_phi(int_type, "i")?;
2315        curr_index.add_incoming(&[(&thread_start.unwrap_or(start_index), preblock)]);
2316
2317        let elmt_index = curr_index.as_basic_value().into_int_value();
2318        let indices_int = self.expr_indices_from_elmt_index(elmt_index, elmt, name)?;
2319
2320        // loop body - eval expression
2321        let float_value =
2322            self.jit_compile_expr(name, elmt.expr(), indices_int.as_slice(), elmt, elmt_index)?;
2323
2324        self.jit_compile_broadcast_and_store(name, elmt, float_value, elmt_index, translation)?;
2325
2326        // jit_compile_expr or jit_compile_broadcast_and_store may have changed the current block
2327        let end_loop_block = self.builder.get_insert_block().unwrap();
2328
2329        // increment loop index
2330        let one = int_type.const_int(1, false);
2331        let next_index = self.builder.build_int_add(elmt_index, one, name)?;
2332        curr_index.add_incoming(&[(&next_index, end_loop_block)]);
2333
2334        // loop condition
2335        let loop_while = self.builder.build_int_compare(
2336            IntPredicate::ULT,
2337            next_index,
2338            thread_end.unwrap_or(end_index),
2339            name,
2340        )?;
2341        let post_block = self.context.append_basic_block(self.fn_value(), name);
2342        self.builder
2343            .build_conditional_branch(loop_while, loop_block, post_block)?;
2344        self.builder.position_at_end(post_block);
2345
2346        if self.threaded {
2347            self.jit_end_threading(
2348                thread_start.unwrap(),
2349                thread_end.unwrap(),
2350                test_done.unwrap(),
2351                next.unwrap(),
2352            )?;
2353        }
2354
2355        Ok(())
2356    }
2357
    // for diagonal blocks we loop through the diagonal elements; the index is the same in every dimension, and we compile the expression passing in this index
2359    fn jit_compile_diagonal_block(
2360        &mut self,
2361        name: &str,
2362        elmt: &TensorBlock,
2363        translation: &Translation,
2364    ) -> Result<()> {
2365        let int_type = self.int_type;
2366
2367        let start_index = int_type.const_int(0, false);
2368        let end_index = int_type.const_int(elmt.expr_layout().nnz().try_into().unwrap(), false);
2369
2370        let (thread_start, thread_end, test_done, next) = if self.threaded {
2371            let (start, end, test_done, next) = self.jit_threading_limits(end_index)?;
2372            (Some(start), Some(end), Some(test_done), Some(next))
2373        } else {
2374            (None, None, None, None)
2375        };
2376
        // loop through the diagonal elements
2378        let preblock = self.builder.get_insert_block().unwrap();
2379        let start_loop_block = self.context.append_basic_block(self.fn_value(), name);
2380        self.builder.build_unconditional_branch(start_loop_block)?;
2381        self.builder.position_at_end(start_loop_block);
2382
2383        let curr_index = self.builder.build_phi(int_type, "i")?;
2384        curr_index.add_incoming(&[(&thread_start.unwrap_or(start_index), preblock)]);
2385
2386        // loop body - index is just the same for each element
2387        let elmt_index = curr_index.as_basic_value().into_int_value();
2388        let indices_int: Vec<IntValue> =
2389            (0..elmt.expr_layout().rank()).map(|_| elmt_index).collect();
2390
2391        // loop body - eval expression
2392        let float_value =
2393            self.jit_compile_expr(name, elmt.expr(), indices_int.as_slice(), elmt, elmt_index)?;
2394
2395        // loop body - store result
2396        self.jit_compile_broadcast_and_store(name, elmt, float_value, elmt_index, translation)?;
2397
2398        let end_loop_block = self.builder.get_insert_block().unwrap();
2399
2400        // increment loop index
2401        let one = int_type.const_int(1, false);
2402        let next_index = self.builder.build_int_add(elmt_index, one, name)?;
2403        curr_index.add_incoming(&[(&next_index, end_loop_block)]);
2404
2405        // loop condition
2406        let loop_while = self.builder.build_int_compare(
2407            IntPredicate::ULT,
2408            next_index,
2409            thread_end.unwrap_or(end_index),
2410            name,
2411        )?;
2412        let post_block = self.context.append_basic_block(self.fn_value(), name);
2413        self.builder
2414            .build_conditional_branch(loop_while, start_loop_block, post_block)?;
2415        self.builder.position_at_end(post_block);
2416
2417        if self.threaded {
2418            self.jit_end_threading(
2419                thread_start.unwrap(),
2420                thread_end.unwrap(),
2421                test_done.unwrap(),
2422                next.unwrap(),
2423            )?;
2424        }
2425
2426        Ok(())
2427    }
2428
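    /// Stores an expression value, expanding it with an inner loop if the
    /// translation broadcasts it over multiple output elements.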
2429    fn jit_compile_broadcast_and_store(
2430        &mut self,
2431        name: &str,
2432        elmt: &TensorBlock,
2433        float_value: FloatValue<'ctx>,
2434        expr_index: IntValue<'ctx>,
2435        translation: &Translation,
2436    ) -> Result<()> {
2437        let int_type = self.int_type;
2438        let one = int_type.const_int(1, false);
2439        let zero = int_type.const_int(0, false);
2440        let pre_block = self.builder.get_insert_block().unwrap();
2441        match translation.source {
2442            TranslationFrom::Broadcast {
2443                broadcast_by: _,
2444                broadcast_len,
2445            } => {
2446                let bcast_start_index = zero;
2447                let bcast_end_index = int_type.const_int(broadcast_len.try_into().unwrap(), false);
2448
2449                // setup loop block
2450                let bcast_block = self.context.append_basic_block(self.fn_value(), name);
2451                self.builder.build_unconditional_branch(bcast_block)?;
2452                self.builder.position_at_end(bcast_block);
2453                let bcast_index = self.builder.build_phi(int_type, "broadcast_index")?;
2454                bcast_index.add_incoming(&[(&bcast_start_index, pre_block)]);
2455
2456                // store value
2457                let store_index = self.builder.build_int_add(
2458                    self.builder
2459                        .build_int_mul(expr_index, bcast_end_index, "store_index")?,
2460                    bcast_index.as_basic_value().into_int_value(),
2461                    "bcast_store_index",
2462                )?;
2463                self.jit_compile_store(name, elmt, store_index, float_value, translation)?;
2464
2465                // increment index
2466                let bcast_next_index = self.builder.build_int_add(
2467                    bcast_index.as_basic_value().into_int_value(),
2468                    one,
2469                    name,
2470                )?;
2471                bcast_index.add_incoming(&[(&bcast_next_index, bcast_block)]);
2472
2473                // loop condition
2474                let bcast_cond = self.builder.build_int_compare(
2475                    IntPredicate::ULT,
2476                    bcast_next_index,
2477                    bcast_end_index,
2478                    "broadcast_cond",
2479                )?;
2480                let post_bcast_block = self.context.append_basic_block(self.fn_value(), name);
2481                self.builder
2482                    .build_conditional_branch(bcast_cond, bcast_block, post_bcast_block)?;
2483                self.builder.position_at_end(post_bcast_block);
2484                Ok(())
2485            }
2486            TranslationFrom::ElementWise | TranslationFrom::DiagonalContraction { .. } => {
2487                self.jit_compile_store(name, elmt, expr_index, float_value, translation)?;
2488                Ok(())
2489            }
2490            _ => Err(anyhow!("Invalid translation")),
2491        }
2492    }
2493
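    /// Stores a value into the result tensor, translating `store_index` to a
    /// contiguous offset or through the sparse index map as required.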
2494    fn jit_compile_store(
2495        &mut self,
2496        name: &str,
2497        elmt: &TensorBlock,
2498        store_index: IntValue<'ctx>,
2499        float_value: FloatValue<'ctx>,
2500        translation: &Translation,
2501    ) -> Result<()> {
2502        let int_type = self.int_type;
2503        let res_index = match &translation.target {
2504            TranslationTo::Contiguous { start, end: _ } => {
2505                let start_const = int_type.const_int((*start).try_into().unwrap(), false);
2506                self.builder.build_int_add(start_const, store_index, name)?
2507            }
2508            TranslationTo::Sparse { indices: _ } => {
2509                // load store index from layout
2510                let translate_index = self
2511                    .layout
2512                    .get_translation_index(elmt.expr_layout(), elmt.layout())
2513                    .unwrap();
2514                let translate_store_index =
2515                    translate_index + translation.get_to_index_in_data_layout();
2516                let translate_store_index =
2517                    int_type.const_int(translate_store_index.try_into().unwrap(), false);
                let curr_index =
                    self.builder
                        .build_int_add(store_index, translate_store_index, name)?;
2522                let ptr = Self::get_ptr_to_index(
2523                    &self.builder,
2524                    self.int_type,
2525                    self.get_param("indices"),
2526                    curr_index,
2527                    name,
2528                );
2529                self.build_load(self.int_type, ptr, name)?.into_int_value()
2530            }
2531        };
2532
2533        let resi_ptr = Self::get_ptr_to_index(
2534            &self.builder,
2535            self.real_type,
2536            &self.tensor_ptr(),
2537            res_index,
2538            name,
2539        );
2540        self.builder.build_store(resi_ptr, float_value)?;
2541        Ok(())
2542    }
2543
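    /// Recursively compiles an expression AST to a float value, given the
    /// current loop indices and the flat expression index.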
2544    fn jit_compile_expr(
2545        &mut self,
2546        name: &str,
2547        expr: &Ast,
2548        index: &[IntValue<'ctx>],
2549        elmt: &TensorBlock,
2550        expr_index: IntValue<'ctx>,
2551    ) -> Result<FloatValue<'ctx>> {
2552        let name = elmt.name().unwrap_or(name);
2553        match &expr.kind {
2554            AstKind::Binop(binop) => {
2555                let lhs =
2556                    self.jit_compile_expr(name, binop.left.as_ref(), index, elmt, expr_index)?;
2557                let rhs =
2558                    self.jit_compile_expr(name, binop.right.as_ref(), index, elmt, expr_index)?;
2559                match binop.op {
2560                    '*' => Ok(self.builder.build_float_mul(lhs, rhs, name)?),
2561                    '/' => Ok(self.builder.build_float_div(lhs, rhs, name)?),
2562                    '-' => Ok(self.builder.build_float_sub(lhs, rhs, name)?),
2563                    '+' => Ok(self.builder.build_float_add(lhs, rhs, name)?),
2564                    unknown => Err(anyhow!("unknown binop op '{}'", unknown)),
2565                }
2566            }
2567            AstKind::Monop(monop) => {
2568                let child =
2569                    self.jit_compile_expr(name, monop.child.as_ref(), index, elmt, expr_index)?;
2570                match monop.op {
2571                    '-' => Ok(self.builder.build_float_neg(child, name)?),
2572                    unknown => Err(anyhow!("unknown monop op '{}'", unknown)),
2573                }
2574            }
2575            AstKind::Call(call) => match self.get_function(call.fn_name) {
2576                Some(function) => {
2577                    let mut args: Vec<BasicMetadataValueEnum> = Vec::new();
2578                    for arg in call.args.iter() {
2579                        let arg_val =
2580                            self.jit_compile_expr(name, arg.as_ref(), index, elmt, expr_index)?;
2581                        args.push(BasicMetadataValueEnum::FloatValue(arg_val));
2582                    }
2583                    let ret_value = self
2584                        .builder
2585                        .build_call(function, args.as_slice(), name)?
2586                        .try_as_basic_value()
2587                        .unwrap_basic()
2588                        .into_float_value();
2589                    Ok(ret_value)
2590                }
2591                None => Err(anyhow!("unknown function call '{}'", call.fn_name)),
2592            },
2593            AstKind::CallArg(arg) => {
2594                self.jit_compile_expr(name, &arg.expression, index, elmt, expr_index)
2595            }
2596            AstKind::Number(value) => Ok(self.real_type.const_float(*value)),
2597            AstKind::Name(iname) => {
2598                let ptr = self.get_param(iname.name);
2599                let layout = self.layout.get_layout(iname.name).unwrap();
2600                let iname_elmt_index = if layout.is_dense() {
2601                    // permute indices based on the index chars of this tensor
2602                    let mut no_transform = true;
2603                    let mut iname_index = Vec::new();
2604                    for (i, c) in iname.indices.iter().enumerate() {
2605                        // find the position index of this index char in the tensor's index chars,
2606                        // if it's not found then it must be a contraction index so is at the end
2607                        let pi = elmt
2608                            .indices()
2609                            .iter()
2610                            .position(|x| x == c)
2611                            .unwrap_or(elmt.indices().len());
                        // if the name is indexed, add the index range's start to index[pi]
2613                        if let Some(indice) =
2614                            iname.indice.as_ref().map(|i| i.kind.as_indice().unwrap())
2615                        {
2616                            let start = indice.first.as_ref().kind.as_integer().unwrap();
2617                            let start_intval = self
2618                                .context
2619                                .i32_type()
2620                                .const_int(start.try_into().unwrap(), false);
2621                            // if we are indexing a single element, the index may be out of bounds
2622                            let index_pi = if pi >= index.len() {
2623                                self.context.i32_type().const_int(0, false)
2624                            } else {
2625                                index[pi]
2626                            };
2627                            let index_pi =
2628                                self.builder.build_int_add(index_pi, start_intval, name)?;
2629                            iname_index.push(index_pi);
2630                        } else {
2631                            iname_index.push(index[pi]);
2632                        }
2633                        no_transform = no_transform && pi == i;
2634                    }
2635                    // broadcasting, if the shape is 1 then the index is always 0
2636                    let iname_index: Vec<IntValue> = iname_index
2637                        .iter()
2638                        .enumerate()
2639                        .map(|(i, &idx)| {
2640                            if layout.shape()[i] == 1 {
2641                                self.context.i32_type().const_int(0, false)
2642                            } else {
2643                                idx
2644                            }
2645                        })
2646                        .collect();
2647                    // calculate the element index using iname_index and the shape of the tensor
2648                    // TODO: can we optimise this by using expr_index, and also including elmt_index?
2649                    if !iname_index.is_empty() {
                        // if iname_index is (a, b, c) and the shape is s, calculate iname_elmt_index = c + b * s[2] + a * s[1] * s[2]
2651                        let mut iname_elmt_index = *iname_index.last().unwrap();
2652                        let mut stride = 1u64;
2653                        for i in (0..iname_index.len() - 1).rev() {
2654                            let iname_i = iname_index[i];
2655                            let shapei: u64 = layout.shape()[i + 1].try_into().unwrap();
2656                            stride *= shapei;
2657                            let stride_intval = self.context.i32_type().const_int(stride, false);
2658                            let stride_mul_i =
2659                                self.builder.build_int_mul(stride_intval, iname_i, name)?;
2660                            iname_elmt_index =
2661                                self.builder
2662                                    .build_int_add(iname_elmt_index, stride_mul_i, name)?;
2663                        }
2664                        iname_elmt_index
                    } else {
                        // no indices, so this is a scalar: the element index is just zero
                        self.context.i32_type().const_int(0, false)
                    }
2670                } else if layout.is_sparse() || layout.is_diagonal() {
2671                    let expr_layout = elmt.expr_layout();
2672
2673                    if expr_layout != layout {
                        // the expression layout differs from the name's layout, so map the
                        // index through the binary layout map, i.e.
                        //     mapped_index = indices[binary_layout_index + expr_index]
                        // if mapped_index == -1 the element is structurally zero and the
                        // expression evaluates to 0; otherwise load the value at mapped_index.
                        // this requires a conditional branch, built below
                        let permutation =
                            DataLayout::permutation(elmt, iname.indices.as_slice(), layout);
                        if let Some(base_binary_layout_index) =
                            self.layout
                                .get_binary_layout_index(layout, expr_layout, permutation)
                        {
                            let binary_layout_index = self.builder.build_int_add(
                                self.int_type
                                    .const_int(base_binary_layout_index.try_into().unwrap(), false),
                                expr_index,
                                name,
                            )?;

                            let indices_ptr = Self::get_ptr_to_index(
                                &self.builder,
                                self.int_type,
                                self.get_param("indices"),
                                binary_layout_index,
                                name,
                            );
                            let mapped_index = self
                                .build_load(self.int_type, indices_ptr, name)?
                                .into_int_value();
                            let is_less_than_zero = self.builder.build_int_compare(
                                IntPredicate::SLT,
                                mapped_index,
                                self.int_type.const_int(0, true),
                                "sparse_index_check",
                            )?;
                            let is_less_than_zero_block =
                                self.context.append_basic_block(self.fn_value(), "lt_zero");
                            let not_less_than_zero_block = self
                                .context
                                .append_basic_block(self.fn_value(), "not_lt_zero");
                            let merge_block =
                                self.context.append_basic_block(self.fn_value(), "merge");
                            self.builder.build_conditional_branch(
                                is_less_than_zero,
                                is_less_than_zero_block,
                                not_less_than_zero_block,
                            )?;

                            // if mapped index < 0 return 0
                            self.builder.position_at_end(is_less_than_zero_block);
                            let zero_value = self.real_type.const_float(0.);
                            self.builder.build_unconditional_branch(merge_block)?;

                            // if mapped index >= 0 load the value at that index
                            self.builder.position_at_end(not_less_than_zero_block);
                            let value_ptr = Self::get_ptr_to_index(
                                &self.builder,
                                self.real_type,
                                ptr,
                                mapped_index,
                                name,
                            );
                            let value = self
                                .build_load(self.real_type, value_ptr, name)?
                                .into_float_value();
                            self.builder.build_unconditional_branch(merge_block)?;

                            // return value or 0 from the if statement
                            self.builder.position_at_end(merge_block);
                            let if_return_value =
                                self.builder.build_phi(self.real_type, "sparse_value")?;
                            if_return_value.add_incoming(&[(&zero_value, is_less_than_zero_block)]);
                            if_return_value.add_incoming(&[(&value, not_less_than_zero_block)]);

                            let phi_value = if_return_value.as_basic_value().into_float_value();
                            return Ok(phi_value);
                        } else {
                            expr_index
                        }
                    } else {
                        // the layouts are the same, so we can use expr_index directly
                        expr_index
                    }
                } else {
                    panic!("unexpected layout");
                };
                let value_ptr = Self::get_ptr_to_index(
                    &self.builder,
                    self.real_type,
                    ptr,
                    iname_elmt_index,
                    name,
                );
                Ok(self
                    .build_load(self.real_type, value_ptr, name)?
                    .into_float_value())
            }
            AstKind::NamedGradient(name) => {
                let name_str = name.to_string();
                let ptr = self.get_param(name_str.as_str());
                Ok(self
                    .build_load(self.real_type, *ptr, name_str.as_str())?
                    .into_float_value())
            }
            AstKind::Index(_) => todo!(),
            AstKind::Slice(_) => todo!(),
            AstKind::Integer(_) => todo!(),
            _ => panic!("unexpected AstKind"),
        }
    }

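    /// Reset the per-function codegen state (variable map, current function and
    /// current tensor pointer) so the next function starts from a clean slate.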
    fn clear(&mut self) {
        self.variables.clear();
        //self.functions.clear();
        self.fn_value_opt = None;
        self.tensor_ptr_opt = None;
    }

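    /// Make a function argument addressable: pointer arguments are returned
    /// as-is, while scalar (float/int) arguments are spilled to an entry-block
    /// alloca so that every parameter can be accessed through a pointer.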
    fn function_arg_alloca(&mut self, name: &str, arg: BasicValueEnum<'ctx>) -> PointerValue<'ctx> {
        match arg {
            BasicValueEnum::PointerValue(v) => v,
            BasicValueEnum::FloatValue(_) | BasicValueEnum::IntValue(_) => {
                let alloca = self
                    .create_entry_block_builder()
                    .build_alloca(arg.get_type(), name)
                    .unwrap();
                self.builder.build_store(alloca, arg).unwrap();
                alloca
            }
            _ => unreachable!(),
        }
    }

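    /// Compile the `set_u0(u0, data, thread_id, thread_dim)` function, which
    /// evaluates the input-dependent definitions and then writes the initial
    /// state into `u0`, emitting a synchronisation barrier after each tensor.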
    pub fn compile_set_u0<'m>(
        &mut self,
        model: &'m DiscreteModel,
        code: Option<&str>,
    ) -> Result<FunctionValue<'ctx>> {
        self.clear();
        let fn_arg_names = &["u0", "data", "thread_id", "thread_dim"];
        let function = self.add_function(
            "set_u0",
            fn_arg_names,
            &[
                self.real_ptr_type.into(),
                self.real_ptr_type.into(),
                self.int_type.into(),
                self.int_type.into(),
            ],
            None,
            false,
        );

        // add noalias to the pointer arguments
        let alias_id = Attribute::get_named_enum_kind_id("noalias");
        let noalias = self.context.create_enum_attribute(alias_id, 0);
        for i in &[0, 1] {
            function.add_attribute(AttributeLoc::Param(*i), noalias);
        }

        let _basic_block = self.start_function(function, code);

        for (i, arg) in function.get_param_iter().enumerate() {
            let name = fn_arg_names[i];
            let alloca = self.function_arg_alloca(name, arg);
            self.insert_param(name, alloca);
        }

        self.insert_data(model);
        self.insert_indices();

        let mut nbarriers = 0;
        let total_barriers = (model.input_dep_defns().len() + 1) as u64;
        #[allow(clippy::explicit_counter_loop)]
        for a in model.input_dep_defns() {
            self.jit_compile_tensor(a, Some(*self.get_var(a)), code)?;
            self.jit_compile_call_barrier(nbarriers, total_barriers);
            nbarriers += 1;
        }

        self.jit_compile_tensor(model.state(), Some(*self.get_param("u0")), code)?;
        self.jit_compile_call_barrier(nbarriers, total_barriers);

        self.builder.build_return(None)?;

        if function.verify(true) {
            Ok(function)
        } else {
            function.print_to_stderr();
            unsafe {
                function.delete();
            }
            Err(anyhow!("Invalid generated function."))
        }
    }

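    /// Compile `calc_out` (or `calc_out_full` if `include_constants` is true),
    /// which evaluates the time- and state-dependent definitions followed by
    /// the output tensor; the `_full` variant also recomputes the
    /// input-dependent definitions first.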
    pub fn compile_calc_out<'m>(
        &mut self,
        model: &'m DiscreteModel,
        include_constants: bool,
        code: Option<&str>,
    ) -> Result<FunctionValue<'ctx>> {
        self.clear();
        let fn_arg_names = &["t", "u", "data", "out", "thread_id", "thread_dim"];
        let function_name = if include_constants {
            "calc_out_full"
        } else {
            "calc_out"
        };
        let function = self.add_function(
            function_name,
            fn_arg_names,
            &[
                self.real_type.into(),
                self.real_ptr_type.into(),
                self.real_ptr_type.into(),
                self.real_ptr_type.into(),
                self.int_type.into(),
                self.int_type.into(),
            ],
            None,
            false,
        );

        // add noalias to the pointer arguments
        let alias_id = Attribute::get_named_enum_kind_id("noalias");
        let noalias = self.context.create_enum_attribute(alias_id, 0);
        for i in &[1, 2] {
            function.add_attribute(AttributeLoc::Param(*i), noalias);
        }

        let _basic_block = self.start_function(function, code);

        for (i, arg) in function.get_param_iter().enumerate() {
            let name = fn_arg_names[i];
            let alloca = self.function_arg_alloca(name, arg);
            self.insert_param(name, alloca);
        }

        self.insert_state(model.state());
        self.insert_data(model);
        self.insert_indices();

        // print thread_id and thread_dim
        //let thread_id = function.get_nth_param(3).unwrap();
        //let thread_dim = function.get_nth_param(4).unwrap();
        //self.compile_print_value("thread_id", PrintValue::Int(thread_id.into_int_value()))?;
        //self.compile_print_value("thread_dim", PrintValue::Int(thread_dim.into_int_value()))?;
        if let Some(out) = model.out() {
            let mut nbarriers = 0;
            let mut total_barriers =
                (model.time_dep_defns().len() + model.state_dep_defns().len() + 1) as u64;
            if include_constants {
                total_barriers += model.input_dep_defns().len() as u64;
                // calculate the input-dependent (constant) definitions
                for tensor in model.input_dep_defns() {
                    self.jit_compile_tensor(tensor, Some(*self.get_var(tensor)), code)?;
                    self.jit_compile_call_barrier(nbarriers, total_barriers);
                    nbarriers += 1;
                }
            }

            // calculate time dependent definitions
            for tensor in model.time_dep_defns() {
                self.jit_compile_tensor(tensor, Some(*self.get_var(tensor)), code)?;
                self.jit_compile_call_barrier(nbarriers, total_barriers);
                nbarriers += 1;
            }

            // calculate state dependent definitions
            #[allow(clippy::explicit_counter_loop)]
            for a in model.state_dep_defns() {
                self.jit_compile_tensor(a, Some(*self.get_var(a)), code)?;
                self.jit_compile_call_barrier(nbarriers, total_barriers);
                nbarriers += 1;
            }

            self.jit_compile_tensor(out, Some(*self.get_var(out)), code)?;
            self.jit_compile_call_barrier(nbarriers, total_barriers);
        }
        self.builder.build_return(None)?;

        if function.verify(true) {
            Ok(function)
        } else {
            function.print_to_stderr();
            unsafe {
                function.delete();
            }
            Err(anyhow!("Invalid generated function."))
        }
    }

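    /// Compile `calc_stop`, which evaluates the stop (root-finding) tensor into
    /// the `root` buffer; the generated function is a no-op if the model has no
    /// stop condition.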
    pub fn compile_calc_stop<'m>(
        &mut self,
        model: &'m DiscreteModel,
        code: Option<&str>,
    ) -> Result<FunctionValue<'ctx>> {
        self.clear();
        let fn_arg_names = &["t", "u", "data", "root", "thread_id", "thread_dim"];
        let function = self.add_function(
            "calc_stop",
            fn_arg_names,
            &[
                self.real_type.into(),
                self.real_ptr_type.into(),
                self.real_ptr_type.into(),
                self.real_ptr_type.into(),
                self.int_type.into(),
                self.int_type.into(),
            ],
            None,
            false,
        );

        // add noalias to the pointer arguments
        let alias_id = Attribute::get_named_enum_kind_id("noalias");
        let noalias = self.context.create_enum_attribute(alias_id, 0);
        for i in &[1, 2, 3] {
            function.add_attribute(AttributeLoc::Param(*i), noalias);
        }

        let _basic_block = self.start_function(function, code);

        for (i, arg) in function.get_param_iter().enumerate() {
            let name = fn_arg_names[i];
            let alloca = self.function_arg_alloca(name, arg);
            self.insert_param(name, alloca);
        }

        self.insert_state(model.state());
        self.insert_data(model);
        self.insert_indices();

        if let Some(stop) = model.stop() {
            // calculate time dependent definitions
            let mut nbarriers = 0;
            let total_barriers =
                (model.time_dep_defns().len() + model.state_dep_defns().len() + 1) as u64;
            for tensor in model.time_dep_defns() {
                self.jit_compile_tensor(tensor, Some(*self.get_var(tensor)), code)?;
                self.jit_compile_call_barrier(nbarriers, total_barriers);
                nbarriers += 1;
            }

            // calculate state dependent definitions
            for a in model.state_dep_defns() {
                self.jit_compile_tensor(a, Some(*self.get_var(a)), code)?;
                self.jit_compile_call_barrier(nbarriers, total_barriers);
                nbarriers += 1;
            }

            let res_ptr = self.get_param("root");
            self.jit_compile_tensor(stop, Some(*res_ptr), code)?;
            self.jit_compile_call_barrier(nbarriers, total_barriers);
        }
        self.builder.build_return(None)?;

        if function.verify(true) {
            Ok(function)
        } else {
            function.print_to_stderr();
            unsafe {
                function.delete();
            }
            Err(anyhow!("Invalid generated function."))
        }
    }

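    /// Compile `rhs` (or `rhs_full` if `include_constants` is true), which
    /// evaluates the definitions that F depends on and then writes F(t, u) into
    /// the `rr` buffer.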
    pub fn compile_rhs<'m>(
        &mut self,
        model: &'m DiscreteModel,
        include_constants: bool,
        code: Option<&str>,
    ) -> Result<FunctionValue<'ctx>> {
        self.clear();
        let fn_arg_names = &["t", "u", "data", "rr", "thread_id", "thread_dim"];
        let function_name = if include_constants { "rhs_full" } else { "rhs" };
        let function = self.add_function(
            function_name,
            fn_arg_names,
            &[
                self.real_type.into(),
                self.real_ptr_type.into(),
                self.real_ptr_type.into(),
                self.real_ptr_type.into(),
                self.int_type.into(),
                self.int_type.into(),
            ],
            None,
            false,
        );

        // add noalias to the pointer arguments
        let alias_id = Attribute::get_named_enum_kind_id("noalias");
        let noalias = self.context.create_enum_attribute(alias_id, 0);
        for i in &[1, 2, 3] {
            function.add_attribute(AttributeLoc::Param(*i), noalias);
        }

        let _basic_block = self.start_function(function, code);

        for (i, arg) in function.get_param_iter().enumerate() {
            let name = fn_arg_names[i];
            let alloca = self.function_arg_alloca(name, arg);
            self.insert_param(name, alloca);
        }

        self.insert_state(model.state());
        self.insert_data(model);
        self.insert_indices();

        let mut nbarriers = 0;
        let mut total_barriers =
            (model.time_dep_defns().len() + model.state_dep_defns().len() + 1) as u64;
        if include_constants {
            total_barriers += model.input_dep_defns().len() as u64;
            // calculate constant definitions
            for tensor in model.input_dep_defns() {
                self.jit_compile_tensor(tensor, Some(*self.get_var(tensor)), code)?;
                self.jit_compile_call_barrier(nbarriers, total_barriers);
                nbarriers += 1;
            }
        }

        // calculate time dependent definitions
        for tensor in model.time_dep_defns() {
            self.jit_compile_tensor(tensor, Some(*self.get_var(tensor)), code)?;
            self.jit_compile_call_barrier(nbarriers, total_barriers);
            nbarriers += 1;
        }

        // TODO: could split state dep defns into before and after F
        for a in model.state_dep_defns() {
            self.jit_compile_tensor(a, Some(*self.get_var(a)), code)?;
            self.jit_compile_call_barrier(nbarriers, total_barriers);
            nbarriers += 1;
        }

        // F
        let res_ptr = self.get_param("rr");
        self.jit_compile_tensor(model.rhs(), Some(*res_ptr), code)?;
        self.jit_compile_call_barrier(nbarriers, total_barriers);

        self.builder.build_return(None)?;

        if function.verify(true) {
            Ok(function)
        } else {
            function.print_to_stderr();
            unsafe {
                function.delete();
            }
            Err(anyhow!("Invalid generated function."))
        }
    }

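    /// Compile `mass`, which applies the mass matrix (the `lhs` tensor) to
    /// `dudt` and writes the result into `rr`; the generated function is empty
    /// for models without a `state_dot` and `lhs`.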
    pub fn compile_mass<'m>(
        &mut self,
        model: &'m DiscreteModel,
        code: Option<&str>,
    ) -> Result<FunctionValue<'ctx>> {
        self.clear();
        let fn_arg_names = &["t", "dudt", "data", "rr", "thread_id", "thread_dim"];
        let function = self.add_function(
            "mass",
            fn_arg_names,
            &[
                self.real_type.into(),
                self.real_ptr_type.into(),
                self.real_ptr_type.into(),
                self.real_ptr_type.into(),
                self.int_type.into(),
                self.int_type.into(),
            ],
            None,
            false,
        );

        // add noalias to the pointer arguments
        let alias_id = Attribute::get_named_enum_kind_id("noalias");
        let noalias = self.context.create_enum_attribute(alias_id, 0);
        for i in &[1, 2, 3] {
            function.add_attribute(AttributeLoc::Param(*i), noalias);
        }

        let _basic_block = self.start_function(function, code);

        for (i, arg) in function.get_param_iter().enumerate() {
            let name = fn_arg_names[i];
            let alloca = self.function_arg_alloca(name, arg);
            self.insert_param(name, alloca);
        }

        // only put code in this function if we have both a state_dot and an lhs
        if let (Some(state_dot), Some(lhs)) = (model.state_dot(), model.lhs()) {
            self.insert_dot_state(state_dot);
            self.insert_data(model);
            self.insert_indices();

            // calculate time dependent definitions
            let mut nbarriers = 0;
            let total_barriers =
                (model.time_dep_defns().len() + model.dstate_dep_defns().len() + 1) as u64;
            for tensor in model.time_dep_defns() {
                self.jit_compile_tensor(tensor, Some(*self.get_var(tensor)), code)?;
                self.jit_compile_call_barrier(nbarriers, total_barriers);
                nbarriers += 1;
            }

            for a in model.dstate_dep_defns() {
                self.jit_compile_tensor(a, Some(*self.get_var(a)), code)?;
                self.jit_compile_call_barrier(nbarriers, total_barriers);
                nbarriers += 1;
            }

            // mass
            let res_ptr = self.get_param("rr");
            self.jit_compile_tensor(lhs, Some(*res_ptr), code)?;
            self.jit_compile_call_barrier(nbarriers, total_barriers);
        }

        self.builder.build_return(None)?;

        if function.verify(true) {
            Ok(function)
        } else {
            function.print_to_stderr();
            unsafe {
                function.delete();
            }
            Err(anyhow!("Invalid generated function."))
        }
    }

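    /// Compile a derivative of `original_function` using Enzyme. Every argument
    /// marked `Dup`/`DupNoNeed` gains a shadow parameter immediately after the
    /// primal one, so e.g. an original `(t, u, data, rr)` with `u` and `rr`
    /// duplicated becomes `(t, u, du, data, rr, drr)`. The generated wrapper
    /// simply calls the Enzyme-produced function with the interleaved
    /// primal/shadow arguments.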
    pub fn compile_gradient(
        &mut self,
        original_function: FunctionValue<'ctx>,
        args_type: &[CompileGradientArgType],
        mode: CompileMode,
        fn_name: &str,
    ) -> Result<FunctionValue<'ctx>> {
        self.clear();

        // construct the gradient function
        let mut fn_type: Vec<BasicMetadataTypeEnum> = Vec::new();

        let orig_fn_type_ptr = Self::fn_pointer_type(self.context, original_function.get_type());

        let mut enzyme_fn_type: Vec<BasicMetadataTypeEnum> = vec![orig_fn_type_ptr.into()];
        let mut start_param_index: Vec<u32> = Vec::new();
        let mut ptr_arg_indices: Vec<u32> = Vec::new();
        for (i, arg) in original_function.get_param_iter().enumerate() {
            start_param_index.push(u32::try_from(fn_type.len()).unwrap());
            let arg_type = arg.get_type();
            fn_type.push(arg_type.into());

            // constant args with type T in the original function have 2 args of type [int, T]
            enzyme_fn_type.push(self.int_type.into());
            enzyme_fn_type.push(arg.get_type().into());

            if arg_type.is_pointer_type() {
                ptr_arg_indices.push(u32::try_from(i).unwrap());
            }

            match args_type[i] {
                CompileGradientArgType::Dup | CompileGradientArgType::DupNoNeed => {
                    fn_type.push(arg.get_type().into());
                    enzyme_fn_type.push(arg.get_type().into());
                }
                CompileGradientArgType::Const => {}
            }
        }
        let fn_arg_names = fn_type
            .iter()
            .enumerate()
            .map(|(i, _)| format!("arg{i}"))
            .collect::<Vec<String>>();
        let fn_arg_names_ref = fn_arg_names
            .iter()
            .map(|s| s.as_str())
            .collect::<Vec<&str>>();
        let function = self.add_function(fn_name, &fn_arg_names_ref, &fn_type, None, false);

        // add noalias to the pointer arguments
        let alias_id = Attribute::get_named_enum_kind_id("noalias");
        let noalias = self.context.create_enum_attribute(alias_id, 0);
        for i in ptr_arg_indices {
            function.add_attribute(AttributeLoc::Param(i), noalias);
        }

        let _basic_block = self.start_function(function, None);

        let mut enzyme_fn_args: Vec<BasicMetadataValueEnum> = Vec::new();
        let mut input_activity = Vec::new();
        let mut arg_trees = Vec::new();
        for (i, arg) in original_function.get_param_iter().enumerate() {
            let param_index = start_param_index[i];
            let fn_arg = function.get_nth_param(param_index).unwrap();

            // we'll probably only get doubles or pointers to doubles, so let's assume this for now
            // todo: perhaps refactor this into a recursive function, might be overkill
            let concrete_type = match arg.get_type() {
                BasicTypeEnum::PointerType(_t) => CConcreteType_DT_Pointer,
                BasicTypeEnum::FloatType(_t) => match self.diffsl_real_type {
                    RealType::F32 => CConcreteType_DT_Float,
                    RealType::F64 => CConcreteType_DT_Double,
                },
                BasicTypeEnum::IntType(_) => CConcreteType_DT_Integer,
                _ => panic!("unsupported type"),
            };
            let new_tree = unsafe {
                EnzymeNewTypeTreeCT(
                    concrete_type,
                    self.context.as_ctx_ref() as *mut LLVMOpaqueContext,
                )
            };
            unsafe { EnzymeTypeTreeOnlyEq(new_tree, -1) };

            // pointer to real type
            if concrete_type == CConcreteType_DT_Pointer {
                let inner_concrete_type = match self.diffsl_real_type {
                    RealType::F32 => CConcreteType_DT_Float,
                    RealType::F64 => CConcreteType_DT_Double,
                };
                let inner_new_tree = unsafe {
                    EnzymeNewTypeTreeCT(
                        inner_concrete_type,
                        self.context.as_ctx_ref() as *mut LLVMOpaqueContext,
                    )
                };
                unsafe { EnzymeTypeTreeOnlyEq(inner_new_tree, -1) };
                unsafe { EnzymeTypeTreeOnlyEq(inner_new_tree, -1) };
                unsafe { EnzymeMergeTypeTree(new_tree, inner_new_tree) };
            }

            arg_trees.push(new_tree);
            match args_type[i] {
                CompileGradientArgType::Dup => {
                    // pass in the arg value
                    enzyme_fn_args.push(fn_arg.into());

                    // pass in the darg value
                    let fn_darg = function.get_nth_param(param_index + 1).unwrap();
                    enzyme_fn_args.push(fn_darg.into());

                    input_activity.push(CDIFFE_TYPE_DFT_DUP_ARG);
                }
                CompileGradientArgType::DupNoNeed => {
                    // pass in the arg value
                    enzyme_fn_args.push(fn_arg.into());

                    // pass in the darg value
                    let fn_darg = function.get_nth_param(param_index + 1).unwrap();
                    enzyme_fn_args.push(fn_darg.into());

                    input_activity.push(CDIFFE_TYPE_DFT_DUP_NONEED);
                }
                CompileGradientArgType::Const => {
                    // pass in the arg value
                    enzyme_fn_args.push(fn_arg.into());

                    input_activity.push(CDIFFE_TYPE_DFT_CONSTANT);
                }
            }
        }
        // if we have a void return, this must be false
        let ret_primary_ret = false;
        let diff_ret = false;
        let ret_activity = CDIFFE_TYPE_DFT_CONSTANT;
        let ret_tree = unsafe {
            EnzymeNewTypeTreeCT(
                CConcreteType_DT_Anything,
                self.context.as_ctx_ref() as *mut LLVMOpaqueContext,
            )
        };

        // always optimize
        let fnc_opt_base = true;
        let logic_ref: EnzymeLogicRef = unsafe { CreateEnzymeLogic(fnc_opt_base as u8) };

        let kv_tmp = IntList {
            data: std::ptr::null_mut(),
            size: 0,
        };
        let mut known_values = vec![kv_tmp; input_activity.len()];

        let fn_type_info = CFnTypeInfo {
            Arguments: arg_trees.as_mut_ptr(),
            Return: ret_tree,
            KnownValues: known_values.as_mut_ptr(),
        };

        let type_analysis: EnzymeTypeAnalysisRef =
            unsafe { CreateTypeAnalysis(logic_ref, std::ptr::null_mut(), std::ptr::null_mut(), 0) };

        let mut args_uncacheable = vec![0; arg_trees.len()];

        let enzyme_function = match mode {
            CompileMode::Forward | CompileMode::ForwardSens => unsafe {
                EnzymeCreateForwardDiff(
                    logic_ref, // Logic
                    std::ptr::null_mut(),
                    std::ptr::null_mut(),
                    original_function.as_value_ref(),
                    ret_activity, // LLVM function, return type
                    input_activity.as_mut_ptr(),
                    input_activity.len(), // constant arguments
                    type_analysis,        // type analysis struct
                    ret_primary_ret as u8,
                    CDerivativeMode_DEM_ForwardMode, // return value, dret_used, top_level which was 1
                    1,                               // free memory
                    0,                               // runtime activity
                    0,                               // strong zero
                    1,                               // vector mode width
                    std::ptr::null_mut(),            // additional argument
                    fn_type_info,                    // additional_arg, type info (return + args)
                    1,                               // subsequent calls may write
                    args_uncacheable.as_mut_ptr(),   // overwritten args
                    args_uncacheable.len(),          // overwritten args length
                    std::ptr::null_mut(),            // write augmented function to this
                )
            },
            CompileMode::Reverse | CompileMode::ReverseSens => {
                let mut call_enzyme = || unsafe {
                    EnzymeCreatePrimalAndGradient(
                        logic_ref,
                        std::ptr::null_mut(),
                        std::ptr::null_mut(),
                        original_function.as_value_ref(),
                        ret_activity,
                        input_activity.as_mut_ptr(),
                        input_activity.len(),
                        type_analysis,
                        ret_primary_ret as u8,
                        diff_ret as u8,
                        CDerivativeMode_DEM_ReverseModeCombined,
                        0,
                        0, // strong zero
                        1,
                        1,
                        std::ptr::null_mut(),
                        0, // force anonymous tape
                        fn_type_info,
                        0, // subsequent calls may write
                        args_uncacheable.as_mut_ptr(),
                        args_uncacheable.len(),
                        std::ptr::null_mut(),
                        if self.threaded { 1 } else { 0 }, // atomic add
                    )
                };
                if self.threaded {
                    // the register call handler alters a global variable, so we need to lock it
                    let _lock = my_mutex.lock().unwrap();
                    let barrier_string = CString::new("barrier").unwrap();
                    unsafe {
                        EnzymeRegisterCallHandler(
                            barrier_string.as_ptr(),
                            Some(fwd_handler),
                            Some(rev_handler),
                        )
                    };
                    let ret = call_enzyme();
                    // unregister it so some other thread doesn't use it
                    unsafe { EnzymeRegisterCallHandler(barrier_string.as_ptr(), None, None) };
                    ret
                } else {
                    call_enzyme()
                }
            }
        };

        // free everything
        unsafe { FreeEnzymeLogic(logic_ref) };
        unsafe { FreeTypeAnalysis(type_analysis) };
        unsafe { EnzymeFreeTypeTree(ret_tree) };
        for tree in arg_trees {
            unsafe { EnzymeFreeTypeTree(tree) };
        }

        // call the enzyme-generated function
        let enzyme_function =
            unsafe { FunctionValue::new(enzyme_function as LLVMValueRef) }.unwrap();
        self.builder
            .build_call(enzyme_function, enzyme_fn_args.as_slice(), "enzyme_call")?;

        // return
        self.builder.build_return(None)?;

        if function.verify(true) {
            Ok(function)
        } else {
            function.print_to_stderr();
            enzyme_function.print_to_stderr();
            unsafe {
                function.delete();
            }
            Err(anyhow!("Invalid generated function."))
        }
    }

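    /// Compile `get_dims`, which writes the model dimensions (numbers of
    /// states, inputs, outputs and stop conditions, the data length, and a
    /// has-mass flag) through the caller-provided pointers.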
    pub fn compile_get_dims(&mut self, model: &DiscreteModel) -> Result<FunctionValue<'ctx>> {
        self.clear();
        let fn_arg_names = &["states", "inputs", "outputs", "data", "stop", "has_mass"];
        let function = self.add_function(
            "get_dims",
            fn_arg_names,
            &[
                self.int_ptr_type.into(),
                self.int_ptr_type.into(),
                self.int_ptr_type.into(),
                self.int_ptr_type.into(),
                self.int_ptr_type.into(),
                self.int_ptr_type.into(),
            ],
            None,
            false,
        );
        let _block = self.start_function(function, None);

        for (i, arg) in function.get_param_iter().enumerate() {
            let name = fn_arg_names[i];
            let alloca = self.function_arg_alloca(name, arg);
            self.insert_param(name, alloca);
        }

        self.insert_indices();

        let number_of_states = model.state().nnz() as u64;
        let number_of_inputs = model.input().map(|inp| inp.nnz()).unwrap_or(0) as u64;
        let number_of_outputs = match model.out() {
            Some(out) => out.nnz() as u64,
            None => 0,
        };
        let number_of_stop = if let Some(stop) = model.stop() {
            stop.nnz() as u64
        } else {
            0
        };
        let has_mass = if model.lhs().is_some() { 1u64 } else { 0u64 };
        let data_len = self.layout.data().len() as u64;
        self.builder.build_store(
            *self.get_param("states"),
            self.int_type.const_int(number_of_states, false),
        )?;
        self.builder.build_store(
            *self.get_param("inputs"),
            self.int_type.const_int(number_of_inputs, false),
        )?;
        self.builder.build_store(
            *self.get_param("outputs"),
            self.int_type.const_int(number_of_outputs, false),
        )?;
        self.builder.build_store(
            *self.get_param("data"),
            self.int_type.const_int(data_len, false),
        )?;
        self.builder.build_store(
            *self.get_param("stop"),
            self.int_type.const_int(number_of_stop, false),
        )?;
        self.builder.build_store(
            *self.get_param("has_mass"),
            self.int_type.const_int(has_mass, false),
        )?;
        self.builder.build_return(None)?;

        if function.verify(true) {
            Ok(function)
        } else {
            function.print_to_stderr();
            unsafe {
                function.delete();
            }
            Err(anyhow!("Invalid generated function."))
        }
    }

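    /// Compile `get_tensor_<name>`, which returns a pointer into `data` for the
    /// named tensor together with its number of non-zeros.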
    pub fn compile_get_tensor(
        &mut self,
        model: &DiscreteModel,
        name: &str,
    ) -> Result<FunctionValue<'ctx>> {
        self.clear();
        let real_ptr_ptr_type = Self::pointer_type(self.context, self.real_ptr_type.into());
        let function_name = format!("get_tensor_{name}");
        let fn_arg_names = &["data", "tensor_data", "tensor_size"];
        let function = self.add_function(
            function_name.as_str(),
            fn_arg_names,
            &[
                self.real_ptr_type.into(),
                real_ptr_ptr_type.into(),
                self.int_ptr_type.into(),
            ],
            None,
            false,
        );
        let _basic_block = self.start_function(function, None);

        for (i, arg) in function.get_param_iter().enumerate() {
            let name = fn_arg_names[i];
            let alloca = self.function_arg_alloca(name, arg);
            self.insert_param(name, alloca);
        }

        self.insert_data(model);
        let ptr = self.get_param(name);
        let tensor_size = self.layout.get_layout(name).unwrap().nnz() as u64;
        let tensor_size_value = self.int_type.const_int(tensor_size, false);
        self.builder
            .build_store(*self.get_param("tensor_data"), ptr.as_basic_value_enum())?;
        self.builder
            .build_store(*self.get_param("tensor_size"), tensor_size_value)?;
        self.builder.build_return(None)?;

        if function.verify(true) {
            Ok(function)
        } else {
            function.print_to_stderr();
            unsafe {
                function.delete();
            }
            Err(anyhow!("Invalid generated function."))
        }
    }

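    /// Compile `get_constant_<name>`, the analogue of `get_tensor_<name>` for
    /// constant tensors, which are stored separately from `data` (note there is
    /// no `data` argument here).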
    pub fn compile_get_constant(
        &mut self,
        model: &DiscreteModel,
        name: &str,
    ) -> Result<FunctionValue<'ctx>> {
        self.clear();
        let real_ptr_ptr_type = Self::pointer_type(self.context, self.real_ptr_type.into());
        let function_name = format!("get_constant_{name}");
        let fn_arg_names = &["tensor_data", "tensor_size"];
        let function = self.add_function(
            function_name.as_str(),
            fn_arg_names,
            &[real_ptr_ptr_type.into(), self.int_ptr_type.into()],
            None,
            false,
        );
        let _basic_block = self.start_function(function, None);

        for (i, arg) in function.get_param_iter().enumerate() {
            let name = fn_arg_names[i];
            let alloca = self.function_arg_alloca(name, arg);
            self.insert_param(name, alloca);
        }

        self.insert_constants(model);
        let ptr = self.get_param(name);
        let tensor_size = self.layout.get_layout(name).unwrap().nnz() as u64;
        let tensor_size_value = self.int_type.const_int(tensor_size, false);
        self.builder
            .build_store(*self.get_param("tensor_data"), ptr.as_basic_value_enum())?;
        self.builder
            .build_store(*self.get_param("tensor_size"), tensor_size_value)?;
        self.builder.build_return(None)?;

        if function.verify(true) {
            Ok(function)
        } else {
            function.print_to_stderr();
            unsafe {
                function.delete();
            }
            Err(anyhow!("Invalid generated function."))
        }
    }

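    /// Compile `get_inputs` or `set_inputs`, which copies the input tensor
    /// element-by-element between the caller's `inputs` buffer and the model
    /// data; `is_get` selects the direction of the copy.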
    pub fn compile_inputs(
        &mut self,
        model: &DiscreteModel,
        is_get: bool,
    ) -> Result<FunctionValue<'ctx>> {
        self.clear();
        let function_name = if is_get { "get_inputs" } else { "set_inputs" };
        let fn_arg_names = &["inputs", "data"];
        let function = self.add_function(
            function_name,
            fn_arg_names,
            &[self.real_ptr_type.into(), self.real_ptr_type.into()],
            None,
            false,
        );
        let block = self.start_function(function, None);

        for (i, arg) in function.get_param_iter().enumerate() {
            let name = fn_arg_names[i];
            let alloca = self.function_arg_alloca(name, arg);
            self.insert_param(name, alloca);
        }

        if let Some(input) = model.input() {
            let name = input.name();
            self.insert_tensor(input, false);
            let ptr = self.get_var(input);
            // loop through the elements of this input and get/set them using the inputs ptr
            let inputs_start_index = self.int_type.const_int(0, false);
            let start_index = self.int_type.const_int(0, false);
            let end_index = self
                .int_type
                .const_int(input.nnz().try_into().unwrap(), false);

            let input_block = self.context.append_basic_block(function, name);
            self.builder.build_unconditional_branch(input_block)?;
            self.builder.position_at_end(input_block);
            let index = self.builder.build_phi(self.int_type, "i")?;
            index.add_incoming(&[(&start_index, block)]);

            // loop body - copy the value between inputs and data (direction depends on is_get)
            let curr_input_index = index.as_basic_value().into_int_value();
            let input_ptr =
                Self::get_ptr_to_index(&self.builder, self.real_type, ptr, curr_input_index, name);
            let curr_inputs_index =
                self.builder
                    .build_int_add(inputs_start_index, curr_input_index, name)?;
            let inputs_ptr = Self::get_ptr_to_index(
                &self.builder,
                self.real_type,
                self.get_param("inputs"),
                curr_inputs_index,
                name,
            );
            if is_get {
                let input_value = self
                    .build_load(self.real_type, input_ptr, name)?
                    .into_float_value();
                self.builder.build_store(inputs_ptr, input_value)?;
            } else {
                let input_value = self
                    .build_load(self.real_type, inputs_ptr, name)?
                    .into_float_value();
                self.builder.build_store(input_ptr, input_value)?;
            }

            // increment loop index
            let one = self.int_type.const_int(1, false);
            let next_index = self.builder.build_int_add(curr_input_index, one, name)?;
            index.add_incoming(&[(&next_index, input_block)]);

            // loop condition
            let loop_while =
                self.builder
                    .build_int_compare(IntPredicate::ULT, next_index, end_index, name)?;
            let post_block = self.context.append_basic_block(function, name);
            self.builder
                .build_conditional_branch(loop_while, input_block, post_block)?;
            self.builder.position_at_end(post_block);
        }
        self.builder.build_return(None)?;

        if function.verify(true) {
            Ok(function)
        } else {
            function.print_to_stderr();
            unsafe {
                function.delete();
            }
            Err(anyhow!("Invalid generated function."))
        }
    }

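    /// Compile `set_id`, which fills the `id` vector with 1.0 for differential
    /// states and 0.0 for algebraic states (the id convention used by DAE
    /// solvers such as SUNDIALS IDA).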
    pub fn compile_set_id(&mut self, model: &DiscreteModel) -> Result<FunctionValue<'ctx>> {
        self.clear();
        let fn_arg_names = &["id"];
        let function = self.add_function(
            "set_id",
            fn_arg_names,
            &[self.real_ptr_type.into()],
            None,
            false,
        );
        let mut block = self.start_function(function, None);

        for (i, arg) in function.get_param_iter().enumerate() {
            let name = fn_arg_names[i];
            let alloca = self.function_arg_alloca(name, arg);
            self.insert_param(name, alloca);
        }

        let mut id_index = 0usize;
        for (blk, is_algebraic) in zip(model.state().elmts(), model.is_algebraic()) {
            let name = blk.name().unwrap_or("unknown");
            // loop through the elements of this state blk and set the corresponding elements of id
            let id_start_index = self.int_type.const_int(id_index as u64, false);
            let blk_start_index = self.int_type.const_int(0, false);
            let blk_end_index = self
                .int_type
                .const_int(blk.nnz().try_into().unwrap(), false);

            let blk_block = self.context.append_basic_block(function, name);
            self.builder.build_unconditional_branch(blk_block)?;
            self.builder.position_at_end(blk_block);
            let index = self.builder.build_phi(self.int_type, "i")?;
            index.add_incoming(&[(&blk_start_index, block)]);

            // loop body - store the algebraic/differential flag for this element of id
            let curr_blk_index = index.as_basic_value().into_int_value();
            let curr_id_index = self
                .builder
                .build_int_add(id_start_index, curr_blk_index, name)?;
            let id_ptr = Self::get_ptr_to_index(
                &self.builder,
                self.real_type,
                self.get_param("id"),
                curr_id_index,
                name,
            );
            let is_algebraic_float = if *is_algebraic { 0.0 } else { 1.0 };
            let is_algebraic_value = self.real_type.const_float(is_algebraic_float);
            self.builder.build_store(id_ptr, is_algebraic_value)?;

            // increment loop index
            let one = self.int_type.const_int(1, false);
            let next_index = self.builder.build_int_add(curr_blk_index, one, name)?;
            index.add_incoming(&[(&next_index, blk_block)]);

            // loop condition
            let loop_while = self.builder.build_int_compare(
                IntPredicate::ULT,
                next_index,
                blk_end_index,
                name,
            )?;
            let post_block = self.context.append_basic_block(function, name);
            self.builder
                .build_conditional_branch(loop_while, blk_block, post_block)?;
            self.builder.position_at_end(post_block);

            // get ready for the next blk
            block = post_block;
            id_index += blk.nnz();
        }
        self.builder.build_return(None)?;

        if function.verify(true) {
            Ok(function)
        } else {
            function.print_to_stderr();
            unsafe {
                function.delete();
            }
            Err(anyhow!("Invalid generated function."))
        }
    }

    pub fn module(&self) -> &Module<'ctx> {
        &self.module
    }
}