// diffsl/execution/llvm/codegen.rs

1use aliasable::boxed::AliasableBox;
2use anyhow::{anyhow, Result};
3use inkwell::attributes::{Attribute, AttributeLoc};
4use inkwell::basic_block::BasicBlock;
5use inkwell::builder::Builder;
6use inkwell::context::{AsContextRef, Context};
7use inkwell::execution_engine::ExecutionEngine;
8use inkwell::intrinsics::Intrinsic;
9use inkwell::module::{Linkage, Module};
10use inkwell::passes::PassBuilderOptions;
11use inkwell::targets::{FileType, InitializationConfig, Target, TargetMachine, TargetTriple};
12use inkwell::types::{
13    BasicMetadataTypeEnum, BasicType, BasicTypeEnum, FloatType, FunctionType, IntType, PointerType,
14};
15use inkwell::values::{
16    AsValueRef, BasicMetadataValueEnum, BasicValue, BasicValueEnum, CallSiteValue, FloatValue,
17    FunctionValue, GlobalValue, IntValue, PointerValue,
18};
19use inkwell::{
20    AddressSpace, AtomicOrdering, AtomicRMWBinOp, FloatPredicate, GlobalVisibility, IntPredicate,
21    OptimizationLevel,
22};
23use llvm_sys::core::{
24    LLVMBuildCall2, LLVMGetArgOperand, LLVMGetBasicBlockParent, LLVMGetGlobalParent,
25    LLVMGetInstructionParent, LLVMGetNamedFunction, LLVMGlobalGetValueType, LLVMIsMultithreaded,
26};
27use llvm_sys::prelude::{LLVMBuilderRef, LLVMValueRef};
28use std::collections::HashMap;
29use std::ffi::CString;
30use std::iter::zip;
31use std::pin::Pin;
32use target_lexicon::Triple;
33
// Host-side alias for the generated code's real (floating point) type.
// NOTE(review): not referenced in this chunk — presumably used elsewhere in the file.
type RealType = f64;
35
36use crate::ast::{Ast, AstKind};
37use crate::discretise::{DiscreteModel, Tensor, TensorBlock};
38use crate::enzyme::{
39    CConcreteType_DT_Anything, CConcreteType_DT_Double, CConcreteType_DT_Integer,
40    CConcreteType_DT_Pointer, CDerivativeMode_DEM_ForwardMode,
41    CDerivativeMode_DEM_ReverseModeCombined, CFnTypeInfo, CreateEnzymeLogic, CreateTypeAnalysis,
42    DiffeGradientUtils, EnzymeCreateForwardDiff, EnzymeCreatePrimalAndGradient, EnzymeFreeTypeTree,
43    EnzymeGradientUtilsNewFromOriginal, EnzymeLogicRef, EnzymeMergeTypeTree, EnzymeNewTypeTreeCT,
44    EnzymeRegisterCallHandler, EnzymeTypeAnalysisRef, EnzymeTypeTreeOnlyEq, FreeEnzymeLogic,
45    FreeTypeAnalysis, GradientUtils, IntList, LLVMOpaqueContext, CDIFFE_TYPE_DFT_CONSTANT,
46    CDIFFE_TYPE_DFT_DUP_ARG, CDIFFE_TYPE_DFT_DUP_NONEED,
47};
48use crate::execution::compiler::CompilerMode;
49use crate::execution::module::{
50    CodegenModule, CodegenModuleCompile, CodegenModuleEmit, CodegenModuleJit,
51};
52use crate::execution::{DataLayout, Translation, TranslationFrom, TranslationTo};
53use lazy_static::lazy_static;
54use std::sync::Mutex;
55
lazy_static! {
    // Process-wide lock — presumably serializes access to non-thread-safe
    // Enzyme/LLVM global state during compilation; the lock sites are not
    // visible in this chunk, TODO confirm.
    // NOTE(review): statics are conventionally SCREAMING_SNAKE_CASE
    // (`MY_MUTEX`); renaming is deferred since other uses may live outside
    // this chunk.
    static ref my_mutex: Mutex<i32> = Mutex::new(0i32);
}
59
/// Self-referential pair of an LLVM `Context` and the `CodeGen` that borrows
/// it. The type is made `!Unpin` so it can only be used behind `Pin`, giving
/// the `CodeGen`'s (lifetime-erased) references into `context` a stable
/// address for as long as both are alive.
struct ImmovableLlvmModule {
    // actually has lifetime of `context`
    // declared first so it's dropped before `context`
    codegen: Option<CodeGen<'static>>,
    // safety: we must never move out of this box as long as codegen is alive
    context: AliasableBox<Context>,
    // Marker that opts this type out of `Unpin`, enforcing pinned usage.
    _pin: std::marker::PhantomPinned,
}
68
/// An LLVM-backed codegen module: the pinned context/codegen pair plus the
/// target machine used for optimisation and object-code emission.
pub struct LlvmModule {
    inner: Pin<Box<ImmovableLlvmModule>>,
    machine: TargetMachine,
}

// SAFETY(review): LLVM `Context`s are not thread-safe in general. These impls
// presumably rely on each `LlvmModule` owning its own private `Context` and
// not being used from multiple threads concurrently — confirm before relying
// on `Sync` for shared access.
unsafe impl Send for LlvmModule {}
unsafe impl Sync for LlvmModule {}
76
77impl LlvmModule {
78    fn new(triple: Option<Triple>, model: &DiscreteModel, threaded: bool) -> Result<Self> {
79        let initialization_config = &InitializationConfig::default();
80        Target::initialize_all(initialization_config);
81        let host_triple = Triple::host();
82        let (triple_str, native) = match triple {
83            Some(ref triple) => (triple.to_string(), false),
84            None => (host_triple.to_string(), true),
85        };
86        let triple = TargetTriple::create(triple_str.as_str());
87        let target = Target::from_triple(&triple).unwrap();
88        let cpu = if native {
89            TargetMachine::get_host_cpu_name().to_string()
90        } else {
91            "generic".to_string()
92        };
93        let features = if native {
94            TargetMachine::get_host_cpu_features().to_string()
95        } else {
96            "".to_string()
97        };
98        let machine = target
99            .create_target_machine(
100                &triple,
101                cpu.as_str(),
102                features.as_str(),
103                inkwell::OptimizationLevel::Aggressive,
104                inkwell::targets::RelocMode::Default,
105                inkwell::targets::CodeModel::Default,
106            )
107            .unwrap();
108
109        let context = AliasableBox::from_unique(Box::new(Context::create()));
110        let mut pinned = Self {
111            inner: Box::pin(ImmovableLlvmModule {
112                codegen: None,
113                context,
114                _pin: std::marker::PhantomPinned,
115            }),
116            machine,
117        };
118
119        let context_ref = pinned.inner.context.as_ref();
120        let real_type_str = "f64";
121        let codegen = CodeGen::new(
122            model,
123            context_ref,
124            context_ref.f64_type(),
125            context_ref.i32_type(),
126            real_type_str,
127            threaded,
128        )?;
129        let codegen = unsafe { std::mem::transmute::<CodeGen<'_>, CodeGen<'static>>(codegen) };
130        unsafe { pinned.inner.as_mut().get_unchecked_mut().codegen = Some(codegen) };
131        Ok(pinned)
132    }
133
134    fn pre_autodiff_optimisation(&mut self) -> Result<()> {
135        //let pass_manager_builder = PassManagerBuilder::create();
136        //pass_manager_builder.set_optimization_level(inkwell::OptimizationLevel::Default);
137        //let pass_manager = PassManager::create(());
138        //pass_manager_builder.populate_module_pass_manager(&pass_manager);
139        //pass_manager.run_on(self.codegen().module());
140
141        //self.codegen().module().print_to_stderr();
142        // optimise at -O2 no unrolling before giving to enzyme
143        let pass_options = PassBuilderOptions::create();
144        //pass_options.set_verify_each(true);
145        //pass_options.set_debug_logging(true);
146        //pass_options.set_loop_interleaving(true);
147        pass_options.set_loop_vectorization(false);
148        pass_options.set_loop_slp_vectorization(false);
149        pass_options.set_loop_unrolling(false);
150        //pass_options.set_forget_all_scev_in_loop_unroll(true);
151        //pass_options.set_licm_mssa_opt_cap(1);
152        //pass_options.set_licm_mssa_no_acc_for_promotion_cap(10);
153        //pass_options.set_call_graph_profile(true);
154        //pass_options.set_merge_functions(true);
155
156        //let passes = "default<O2>";
157        let passes = "annotation2metadata,forceattrs,inferattrs,coro-early,function<eager-inv>(lower-expect,simplifycfg<bonus-inst-threshold=1;no-forward-switch-cond;no-switch-range-to-icmp;no-switch-to-lookup;keep-loops;no-hoist-common-insts;no-sink-common-insts>,early-cse<>),openmp-opt,ipsccp,called-value-propagation,globalopt,function(mem2reg),function<eager-inv>(instcombine,simplifycfg<bonus-inst-threshold=1;no-forward-switch-cond;switch-range-to-icmp;no-switch-to-lookup;keep-loops;no-hoist-common-insts;no-sink-common-insts>),require<globals-aa>,function(invalidate<aa>),require<profile-summary>,cgscc(devirt<4>(inline<only-mandatory>,inline,function-attrs,openmp-opt-cgscc,function<eager-inv>(early-cse<memssa>,speculative-execution,jump-threading,correlated-propagation,simplifycfg<bonus-inst-threshold=1;no-forward-switch-cond;switch-range-to-icmp;no-switch-to-lookup;keep-loops;no-hoist-common-insts;no-sink-common-insts>,instcombine,libcalls-shrinkwrap,tailcallelim,simplifycfg<bonus-inst-threshold=1;no-forward-switch-cond;switch-range-to-icmp;no-switch-to-lookup;keep-loops;no-hoist-common-insts;no-sink-common-insts>,reassociate,require<opt-remark-emit>,loop-mssa(loop-instsimplify,loop-simplifycfg,licm<no-allowspeculation>,loop-rotate,licm<allowspeculation>,simple-loop-unswitch<no-nontrivial;trivial>),simplifycfg<bonus-inst-threshold=1;no-forward-switch-cond;switch-range-to-icmp;no-switch-to-lookup;keep-loops;no-hoist-common-insts;no-sink-common-insts>,instcombine,loop(loop-idiom,indvars,loop-deletion),vector-combine,mldst-motion<no-split-footer-bb>,gvn<>,sccp,bdce,instcombine,jump-threading,correlated-propagation,adce,memcpyopt,dse,loop-mssa(licm<allowspeculation>),coro-elide,simplifycfg<bonus-inst-threshold=1;no-forward-switch-cond;switch-range-to-icmp;no-switch-to-lookup;keep-loops;hoist-common-insts;sink-common-insts>,instcombine),coro-split)),deadargelim,coro-cleanup,globalopt,globaldce,elim-avail-extern,rpo-function-attrs,recompute-globalsaa,function<eager-
inv>(float2int,lower-constant-intrinsics,loop(loop-rotate,loop-deletion),loop-distribute,inject-tli-mappings,loop-load-elim,instcombine,simplifycfg<bonus-inst-threshold=1;forward-switch-cond;switch-range-to-icmp;switch-to-lookup;no-keep-loops;hoist-common-insts;sink-common-insts>,vector-combine,instcombine,transform-warning,instcombine,require<opt-remark-emit>,loop-mssa(licm<allowspeculation>),alignment-from-assumptions,loop-sink,instsimplify,div-rem-pairs,tailcallelim,simplifycfg<bonus-inst-threshold=1;no-forward-switch-cond;switch-range-to-icmp;no-switch-to-lookup;keep-loops;no-hoist-common-insts;no-sink-common-insts>),globaldce,constmerge,cg-profile,rel-lookup-table-converter,function(annotation-remarks),verify";
158        let (codegen, machine) = self.codegen_and_machine_mut();
159        codegen
160            .module()
161            .run_passes(passes, machine, pass_options)
162            .map_err(|e| anyhow!("Failed to run passes: {:?}", e))
163    }
164
165    fn post_autodiff_optimisation(&mut self) -> Result<()> {
166        // remove noinline attribute from barrier function as only needed for enzyme
167        if let Some(barrier_func) = self.codegen_mut().module().get_function("barrier") {
168            let nolinline_kind_id = Attribute::get_named_enum_kind_id("noinline");
169            barrier_func.remove_enum_attribute(AttributeLoc::Function, nolinline_kind_id);
170        }
171
172        // remove all preprocess_* functions
173        for f in self.codegen_mut().module.get_functions() {
174            if f.get_name().to_str().unwrap().starts_with("preprocess_") {
175                unsafe { f.delete() };
176            }
177        }
178
179        //self.codegen()
180        //    .module()
181        //    .print_to_file("post_autodiff_optimisation.ll")
182        //    .unwrap();
183
184        let passes = "default<O3>";
185        let (codegen, machine) = self.codegen_and_machine_mut();
186        codegen
187            .module()
188            .run_passes(passes, machine, PassBuilderOptions::create())
189            .map_err(|e| anyhow!("Failed to run passes: {:?}", e))?;
190
191        Ok(())
192    }
193
194    pub fn print(&self) {
195        self.codegen().module().print_to_stderr();
196    }
197    fn codegen_mut(&mut self) -> &mut CodeGen<'static> {
198        unsafe {
199            self.inner
200                .as_mut()
201                .get_unchecked_mut()
202                .codegen
203                .as_mut()
204                .unwrap()
205        }
206    }
207    fn codegen_and_machine_mut(&mut self) -> (&mut CodeGen<'static>, &TargetMachine) {
208        (
209            unsafe {
210                self.inner
211                    .as_mut()
212                    .get_unchecked_mut()
213                    .codegen
214                    .as_mut()
215                    .unwrap()
216            },
217            &self.machine,
218        )
219    }
220
221    fn codegen(&self) -> &CodeGen<'static> {
222        self.inner.as_ref().get_ref().codegen.as_ref().unwrap()
223    }
224}
225
// Umbrella trait marker; `CodegenModule`'s definition (and any default
// behaviour) lives outside this file — the concrete functionality is in the
// Compile/Emit/Jit impls below.
impl CodegenModule for LlvmModule {}
227
228impl CodegenModuleCompile for LlvmModule {
229    fn from_discrete_model(
230        model: &DiscreteModel,
231        mode: CompilerMode,
232        triple: Option<Triple>,
233    ) -> Result<Self> {
234        let thread_dim = mode.thread_dim(model.state().nnz());
235        let threaded = thread_dim > 1;
236        if (unsafe { LLVMIsMultithreaded() } <= 0) {
237            return Err(anyhow!(
238                "LLVM is not compiled with multithreading support, but this codegen module requires it."
239            ));
240        }
241
242        let mut module = Self::new(triple, model, threaded)?;
243
244        let set_u0 = module.codegen_mut().compile_set_u0(model)?;
245        let _calc_stop = module.codegen_mut().compile_calc_stop(model)?;
246        let rhs = module.codegen_mut().compile_rhs(model, false)?;
247        let rhs_full = module.codegen_mut().compile_rhs(model, true)?;
248        let mass = module.codegen_mut().compile_mass(model)?;
249        let calc_out = module.codegen_mut().compile_calc_out(model, false)?;
250        let calc_out_full = module.codegen_mut().compile_calc_out(model, true)?;
251        let _set_id = module.codegen_mut().compile_set_id(model)?;
252        let _get_dims = module.codegen_mut().compile_get_dims(model)?;
253        let set_inputs = module.codegen_mut().compile_inputs(model, false)?;
254        let _get_inputs = module.codegen_mut().compile_inputs(model, true)?;
255        let _set_constants = module.codegen_mut().compile_set_constants(model)?;
256        let tensor_info = module
257            .codegen()
258            .layout
259            .tensors()
260            .map(|(name, is_constant)| (name.to_string(), is_constant))
261            .collect::<Vec<_>>();
262        for (tensor, is_constant) in tensor_info {
263            if is_constant {
264                module
265                    .codegen_mut()
266                    .compile_get_constant(model, tensor.as_str())?;
267            } else {
268                module
269                    .codegen_mut()
270                    .compile_get_tensor(model, tensor.as_str())?;
271            }
272        }
273
274        module.pre_autodiff_optimisation()?;
275
276        module.codegen_mut().compile_gradient(
277            set_u0,
278            &[
279                CompileGradientArgType::DupNoNeed,
280                CompileGradientArgType::DupNoNeed,
281                CompileGradientArgType::Const,
282                CompileGradientArgType::Const,
283            ],
284            CompileMode::Forward,
285            "set_u0_grad",
286        )?;
287
288        module.codegen_mut().compile_gradient(
289            rhs,
290            &[
291                CompileGradientArgType::Const,
292                CompileGradientArgType::DupNoNeed,
293                CompileGradientArgType::DupNoNeed,
294                CompileGradientArgType::DupNoNeed,
295                CompileGradientArgType::Const,
296                CompileGradientArgType::Const,
297            ],
298            CompileMode::Forward,
299            "rhs_grad",
300        )?;
301
302        module.codegen_mut().compile_gradient(
303            calc_out,
304            &[
305                CompileGradientArgType::Const,
306                CompileGradientArgType::DupNoNeed,
307                CompileGradientArgType::DupNoNeed,
308                CompileGradientArgType::DupNoNeed,
309                CompileGradientArgType::Const,
310                CompileGradientArgType::Const,
311            ],
312            CompileMode::Forward,
313            "calc_out_grad",
314        )?;
315        module.codegen_mut().compile_gradient(
316            set_inputs,
317            &[
318                CompileGradientArgType::DupNoNeed,
319                CompileGradientArgType::DupNoNeed,
320            ],
321            CompileMode::Forward,
322            "set_inputs_grad",
323        )?;
324
325        module.codegen_mut().compile_gradient(
326            set_u0,
327            &[
328                CompileGradientArgType::DupNoNeed,
329                CompileGradientArgType::DupNoNeed,
330                CompileGradientArgType::Const,
331                CompileGradientArgType::Const,
332            ],
333            CompileMode::Reverse,
334            "set_u0_rgrad",
335        )?;
336
337        module.codegen_mut().compile_gradient(
338            mass,
339            &[
340                CompileGradientArgType::Const,
341                CompileGradientArgType::DupNoNeed,
342                CompileGradientArgType::DupNoNeed,
343                CompileGradientArgType::DupNoNeed,
344                CompileGradientArgType::Const,
345                CompileGradientArgType::Const,
346            ],
347            CompileMode::Reverse,
348            "mass_rgrad",
349        )?;
350
351        module.codegen_mut().compile_gradient(
352            rhs,
353            &[
354                CompileGradientArgType::Const,
355                CompileGradientArgType::DupNoNeed,
356                CompileGradientArgType::DupNoNeed,
357                CompileGradientArgType::DupNoNeed,
358                CompileGradientArgType::Const,
359                CompileGradientArgType::Const,
360            ],
361            CompileMode::Reverse,
362            "rhs_rgrad",
363        )?;
364        module.codegen_mut().compile_gradient(
365            calc_out,
366            &[
367                CompileGradientArgType::Const,
368                CompileGradientArgType::DupNoNeed,
369                CompileGradientArgType::DupNoNeed,
370                CompileGradientArgType::DupNoNeed,
371                CompileGradientArgType::Const,
372                CompileGradientArgType::Const,
373            ],
374            CompileMode::Reverse,
375            "calc_out_rgrad",
376        )?;
377
378        module.codegen_mut().compile_gradient(
379            set_inputs,
380            &[
381                CompileGradientArgType::DupNoNeed,
382                CompileGradientArgType::DupNoNeed,
383            ],
384            CompileMode::Reverse,
385            "set_inputs_rgrad",
386        )?;
387
388        module.codegen_mut().compile_gradient(
389            rhs_full,
390            &[
391                CompileGradientArgType::Const,
392                CompileGradientArgType::Const,
393                CompileGradientArgType::DupNoNeed,
394                CompileGradientArgType::DupNoNeed,
395                CompileGradientArgType::Const,
396                CompileGradientArgType::Const,
397            ],
398            CompileMode::ForwardSens,
399            "rhs_sgrad",
400        )?;
401
402        module.codegen_mut().compile_gradient(
403            set_u0,
404            &[
405                CompileGradientArgType::DupNoNeed,
406                CompileGradientArgType::DupNoNeed,
407                CompileGradientArgType::Const,
408                CompileGradientArgType::Const,
409            ],
410            CompileMode::ForwardSens,
411            "set_u0_sgrad",
412        )?;
413
414        module.codegen_mut().compile_gradient(
415            calc_out_full,
416            &[
417                CompileGradientArgType::Const,
418                CompileGradientArgType::Const,
419                CompileGradientArgType::DupNoNeed,
420                CompileGradientArgType::DupNoNeed,
421                CompileGradientArgType::Const,
422                CompileGradientArgType::Const,
423            ],
424            CompileMode::ForwardSens,
425            "calc_out_sgrad",
426        )?;
427        module.codegen_mut().compile_gradient(
428            calc_out_full,
429            &[
430                CompileGradientArgType::Const,
431                CompileGradientArgType::Const,
432                CompileGradientArgType::DupNoNeed,
433                CompileGradientArgType::DupNoNeed,
434                CompileGradientArgType::Const,
435                CompileGradientArgType::Const,
436            ],
437            CompileMode::ReverseSens,
438            "calc_out_srgrad",
439        )?;
440
441        module.codegen_mut().compile_gradient(
442            rhs_full,
443            &[
444                CompileGradientArgType::Const,
445                CompileGradientArgType::Const,
446                CompileGradientArgType::DupNoNeed,
447                CompileGradientArgType::DupNoNeed,
448                CompileGradientArgType::Const,
449                CompileGradientArgType::Const,
450            ],
451            CompileMode::ReverseSens,
452            "rhs_srgrad",
453        )?;
454
455        module.post_autodiff_optimisation()?;
456        Ok(module)
457    }
458}
459
460impl CodegenModuleEmit for LlvmModule {
461    fn to_object(self) -> Result<Vec<u8>> {
462        let module = self.codegen().module();
463        //module.print_to_stderr();
464        let buffer = self
465            .machine
466            .write_to_memory_buffer(module, FileType::Object)
467            .unwrap()
468            .as_slice()
469            .to_vec();
470        Ok(buffer)
471    }
472}
473impl CodegenModuleJit for LlvmModule {
474    fn jit(&mut self) -> Result<HashMap<String, *const u8>> {
475        let ee = self
476            .codegen()
477            .module()
478            .create_jit_execution_engine(OptimizationLevel::Aggressive)
479            .map_err(|e| anyhow!("Failed to create JIT execution engine: {:?}", e))?;
480
481        let module = self.codegen().module();
482        let mut symbols = HashMap::new();
483        for function in module.get_functions() {
484            let name = function.get_name().to_str().unwrap();
485            let address = ee.get_function_address(name);
486            if let Ok(address) = address {
487                symbols.insert(name.to_string(), address as *const u8);
488            }
489        }
490        Ok(symbols)
491    }
492}
493
/// Module-level LLVM globals shared by the generated functions. Their names
/// carry an `enzyme_const_` prefix, which marks them inactive for Enzyme AD
/// (see the comment in `Globals::new`).
struct Globals<'ctx> {
    // Integer index table from the data layout; `None` when the layout has no indices.
    indices: Option<GlobalValue<'ctx>>,
    // Storage for constant tensors; `None` when the layout has no constants.
    constants: Option<GlobalValue<'ctx>>,
    // Barrier counter for the threaded runtime; `None` when single-threaded.
    thread_counter: Option<GlobalValue<'ctx>>,
}
499
500impl<'ctx> Globals<'ctx> {
501    fn new(
502        layout: &DataLayout,
503        module: &Module<'ctx>,
504        int_type: IntType<'ctx>,
505        real_type: FloatType<'ctx>,
506        threaded: bool,
507    ) -> Self {
508        let thread_counter = if threaded {
509            let tc = module.add_global(
510                int_type,
511                Some(AddressSpace::default()),
512                "enzyme_const_thread_counter",
513            );
514            // todo: for some reason this doesn't make enzyme think it's inactive
515            // but using enzyme_const in the name does
516            // todo: also, adding this metadata causes the print of the module to segfault,
517            // so maybe a bug in inkwell
518            //let md_string = context.metadata_string("enzyme_inactive");
519            //tc.set_metadata(md_string, 0);
520            let tc_value = int_type.const_zero();
521            tc.set_visibility(GlobalVisibility::Hidden);
522            tc.set_initializer(&tc_value.as_basic_value_enum());
523            Some(tc)
524        } else {
525            None
526        };
527        let constants = if layout.constants().is_empty() {
528            None
529        } else {
530            let constants_array_type =
531                real_type.array_type(u32::try_from(layout.constants().len()).unwrap());
532            let constants = module.add_global(
533                constants_array_type,
534                Some(AddressSpace::default()),
535                "enzyme_const_constants",
536            );
537            constants.set_visibility(GlobalVisibility::Hidden);
538            constants.set_constant(false);
539            constants.set_initializer(&constants_array_type.const_zero());
540            Some(constants)
541        };
542        let indices = if layout.indices().is_empty() {
543            None
544        } else {
545            let indices_array_type =
546                int_type.array_type(u32::try_from(layout.indices().len()).unwrap());
547            let indices_array_values = layout
548                .indices()
549                .iter()
550                .map(|&i| int_type.const_int(i.try_into().unwrap(), false))
551                .collect::<Vec<IntValue>>();
552            let indices_value = int_type.const_array(indices_array_values.as_slice());
553            let indices = module.add_global(
554                indices_array_type,
555                Some(AddressSpace::default()),
556                "enzyme_const_indices",
557            );
558            indices.set_constant(true);
559            indices.set_visibility(GlobalVisibility::Hidden);
560            indices.set_initializer(&indices_value);
561            Some(indices)
562        };
563        Self {
564            indices,
565            thread_counter,
566            constants,
567        }
568    }
569}
570
/// Activity of one function argument when generating a gradient with Enzyme.
/// Variants correspond to the `CDIFFE_TYPE_DFT_*` constants imported from the
/// `enzyme` module (`DFT_CONSTANT`, `DFT_DUP_ARG`, `DFT_DUP_NONEED`).
pub enum CompileGradientArgType {
    /// Argument is not differentiated.
    Const,
    /// Argument is duplicated with a shadow carrying its derivative.
    Dup,
    /// Duplicated with a shadow, but the primal value is not needed.
    DupNoNeed,
}
576
/// Which differentiation pipeline `compile_gradient` generates. The `*Sens`
/// variants are used for the `*_sgrad`/`*_srgrad` sensitivity entry points
/// (see `from_discrete_model`).
pub enum CompileMode {
    Forward,
    ForwardSens,
    Reverse,
    ReverseSens,
}
583
/// IR generator for a `DiscreteModel`: owns the LLVM module and builder plus
/// the bookkeeping used while each function is being compiled.
pub struct CodeGen<'ctx> {
    context: &'ctx inkwell::context::Context,
    module: Module<'ctx>,
    builder: Builder<'ctx>,
    // Named pointers (tensors, function parameters) in scope for the
    // function currently being built.
    variables: HashMap<String, PointerValue<'ctx>>,
    // Previously compiled functions, looked up by name.
    functions: HashMap<String, FunctionValue<'ctx>>,
    // Function currently under construction, if any.
    fn_value_opt: Option<FunctionValue<'ctx>>,
    // Destination pointer of the tensor currently being compiled, if any.
    tensor_ptr_opt: Option<PointerValue<'ctx>>,
    real_type: FloatType<'ctx>,
    real_ptr_type: PointerType<'ctx>,
    // Textual name of the real type, e.g. "f64" — presumably used for
    // name mangling/intrinsic selection elsewhere in the file; not visible here.
    real_type_str: String,
    int_type: IntType<'ctx>,
    int_ptr_type: PointerType<'ctx>,
    layout: DataLayout,
    globals: Globals<'ctx>,
    threaded: bool,
    _ee: Option<ExecutionEngine<'ctx>>,
}
602
/// Enzyme custom forward-mode call handler (registered via
/// `EnzymeRegisterCallHandler`; the registration site is not in this chunk —
/// TODO confirm). All arguments are ignored and 1 is returned, which
/// presumably tells Enzyme the call needs no special shadow handling in
/// forward mode — verify against the Enzyme call-handler API.
unsafe extern "C" fn fwd_handler(
    _builder: LLVMBuilderRef,
    _call_instruction: LLVMValueRef,
    _gutils: *mut GradientUtils,
    _dcall: *mut LLVMValueRef,
    _normal_return: *mut LLVMValueRef,
    _shadow_return: *mut LLVMValueRef,
) -> u8 {
    1
}
613
/// Enzyme custom reverse-mode call handler: replaces the reverse pass of a
/// barrier call with a call to `barrier_grad`, forwarding the original call's
/// three operands.
///
/// # Safety
/// `call_instruction` must be a call with at least three operands, located in
/// a function whose module defines `barrier_grad`; `builder` and `gutils`
/// must be valid LLVM/Enzyme handles positioned for the reverse pass.
unsafe extern "C" fn rev_handler(
    builder: LLVMBuilderRef,
    call_instruction: LLVMValueRef,
    gutils: *mut DiffeGradientUtils,
    _tape: LLVMValueRef,
) {
    // Walk instruction -> basic block -> function -> module to locate
    // `barrier_grad` in the module containing the original call.
    let call_block = LLVMGetInstructionParent(call_instruction);
    let call_function = LLVMGetBasicBlockParent(call_block);
    let module = LLVMGetGlobalParent(call_function);
    let name_c_str = CString::new("barrier_grad").unwrap();
    let barrier_func = LLVMGetNamedFunction(module, name_c_str.as_ptr());
    let barrier_func_type = LLVMGlobalGetValueType(barrier_func);
    // Operands of the original barrier call, in order.
    let barrier_num = LLVMGetArgOperand(call_instruction, 0);
    let total_barriers = LLVMGetArgOperand(call_instruction, 1);
    let thread_count = LLVMGetArgOperand(call_instruction, 2);
    // Translate each primal operand into its counterpart inside the
    // generated gradient function.
    let barrier_num = EnzymeGradientUtilsNewFromOriginal(gutils as *mut GradientUtils, barrier_num);
    let total_barriers =
        EnzymeGradientUtilsNewFromOriginal(gutils as *mut GradientUtils, total_barriers);
    let thread_count =
        EnzymeGradientUtilsNewFromOriginal(gutils as *mut GradientUtils, thread_count);
    let mut args = [barrier_num, total_barriers, thread_count];
    // Empty name: the call returns void, so no result name is needed.
    let name_c_str = CString::new("").unwrap();
    LLVMBuildCall2(
        builder,
        barrier_func_type,
        barrier_func,
        args.as_mut_ptr(),
        args.len() as u32,
        name_c_str.as_ptr(),
    );
}
645
// Debug-print helper: tags a value with how `compile_print_value` should
// format it (`%f` for reals, `%d` for ints).
#[allow(dead_code)]
enum PrintValue<'ctx> {
    Real(FloatValue<'ctx>),
    Int(IntValue<'ctx>),
}
651
652impl<'ctx> CodeGen<'ctx> {
    /// Creates a `CodeGen` for `model`: sets up the module, builder, data
    /// layout, globals and pointer types. For threaded builds, also compiles
    /// the barrier synchronisation functions the generated code will call.
    pub fn new(
        model: &DiscreteModel,
        context: &'ctx inkwell::context::Context,
        real_type: FloatType<'ctx>,
        int_type: IntType<'ctx>,
        real_type_str: &str,
        threaded: bool,
    ) -> Result<Self> {
        let builder = context.create_builder();
        let layout = DataLayout::new(model);
        let module = context.create_module(model.name());
        // Globals must exist before any function body references them.
        let globals = Globals::new(&layout, &module, int_type, real_type, threaded);
        let real_ptr_type = Self::pointer_type(context, real_type.into());
        let int_ptr_type = Self::pointer_type(context, int_type.into());
        let mut ret = Self {
            context,
            module,
            builder,
            real_type,
            real_ptr_type,
            real_type_str: real_type_str.to_owned(),
            variables: HashMap::new(),
            functions: HashMap::new(),
            fn_value_opt: None,
            tensor_ptr_opt: None,
            layout,
            int_type,
            int_ptr_type,
            globals,
            threaded,
            _ee: None,
        };
        if threaded {
            // Compile the barrier runtime up front so later functions can
            // emit calls to it.
            ret.compile_barrier_init()?;
            ret.compile_barrier()?;
            ret.compile_barrier_grad()?;
            // todo: think I can remove this unless I want to call enzyme using a llvm pass
            //ret.globals.add_registered_barrier(ret.context, &ret.module);
        }
        Ok(ret)
    }
694
695    #[allow(dead_code)]
696    fn compile_print_value(
697        &mut self,
698        name: &str,
699        value: PrintValue<'ctx>,
700    ) -> Result<CallSiteValue<'_>> {
701        let void_type = self.context.void_type();
702        // int printf(const char *format, ...)
703        let printf_type = void_type.fn_type(&[self.int_ptr_type.into()], true);
704        // get printf function or declare it if it doesn't exist
705        let printf = match self.module.get_function("printf") {
706            Some(f) => f,
707            None => self
708                .module
709                .add_function("printf", printf_type, Some(Linkage::External)),
710        };
711        let (format_str, format_str_name) = match value {
712            PrintValue::Real(_) => (format!("{name}: %f\n"), format!("real_format_{name}")),
713            PrintValue::Int(_) => (format!("{name}: %d\n"), format!("int_format_{name}")),
714        };
715        // change format_str to c string
716        let format_str = CString::new(format_str).unwrap();
717        // if format_str_name doesn not already exist as a global, add it
718        let format_str_global = match self.module.get_global(format_str_name.as_str()) {
719            Some(g) => g,
720            None => {
721                let format_str = self.context.const_string(format_str.as_bytes(), true);
722                let fmt_str =
723                    self.module
724                        .add_global(format_str.get_type(), None, format_str_name.as_str());
725                fmt_str.set_initializer(&format_str);
726                fmt_str.set_visibility(GlobalVisibility::Hidden);
727                fmt_str
728            }
729        };
730        // call printf with the format string and the value
731        let format_str_ptr = self.builder.build_pointer_cast(
732            format_str_global.as_pointer_value(),
733            self.int_ptr_type,
734            "format_str_ptr",
735        )?;
736        let value: BasicMetadataValueEnum = match value {
737            PrintValue::Real(v) => v.into(),
738            PrintValue::Int(v) => v.into(),
739        };
740        self.builder
741            .build_call(printf, &[format_str_ptr.into(), value], "printf_call")
742            .map_err(|e| anyhow!("Error building call to printf: {}", e))
743    }
744
    /// Compiles `void set_constants(int thread_id, int thread_dim)`, which
    /// evaluates every constant tensor definition into its storage, with a
    /// barrier after each tensor so threads stay in step.
    fn compile_set_constants(&mut self, model: &DiscreteModel) -> Result<FunctionValue<'ctx>> {
        self.clear();
        let void_type = self.context.void_type();
        let fn_type = void_type.fn_type(&[self.int_type.into(), self.int_type.into()], false);
        let fn_arg_names = &["thread_id", "thread_dim"];
        let function = self.module.add_function("set_constants", fn_type, None);

        let basic_block = self.context.append_basic_block(function, "entry");
        self.fn_value_opt = Some(function);
        self.builder.position_at_end(basic_block);

        // Spill each parameter to an alloca and register it by name so the
        // tensor-compilation code can look it up.
        for (i, arg) in function.get_param_iter().enumerate() {
            let name = fn_arg_names[i];
            let alloca = self.function_arg_alloca(name, arg);
            self.insert_param(name, alloca);
        }

        self.insert_indices();
        self.insert_constants(model);

        // Explicit counter kept (instead of `enumerate`) so `nbarriers`'
        // type is inferred from `jit_compile_call_barrier`'s signature.
        let mut nbarriers = 0;
        let total_barriers = (model.constant_defns().len()) as u64;
        #[allow(clippy::explicit_counter_loop)]
        for a in model.constant_defns() {
            self.jit_compile_tensor(a, Some(*self.get_var(a)))?;
            // Synchronise all threads after each constant tensor is filled.
            self.jit_compile_call_barrier(nbarriers, total_barriers);
            nbarriers += 1;
        }

        self.builder.build_return(None)?;

        if function.verify(true) {
            Ok(function)
        } else {
            // Dump the invalid IR for diagnosis, then remove the function so
            // the module stays valid.
            function.print_to_stderr();
            unsafe {
                function.delete();
            }
            Err(anyhow!("Invalid generated function."))
        }
    }
786
787    fn compile_barrier_init(&mut self) -> Result<FunctionValue<'ctx>> {
788        self.clear();
789        let void_type = self.context.void_type();
790        let fn_type = void_type.fn_type(&[], false);
791        let function = self.module.add_function("barrier_init", fn_type, None);
792
793        let entry_block = self.context.append_basic_block(function, "entry");
794
795        self.fn_value_opt = Some(function);
796        self.builder.position_at_end(entry_block);
797
798        let thread_counter = self.globals.thread_counter.unwrap().as_pointer_value();
799        self.builder
800            .build_store(thread_counter, self.int_type.const_zero())?;
801
802        self.builder.build_return(None)?;
803
804        if function.verify(true) {
805            self.functions.insert("barrier_init".to_owned(), function);
806            Ok(function)
807        } else {
808            function.print_to_stderr();
809            unsafe {
810                function.delete();
811            }
812            Err(anyhow!("Invalid generated function."))
813        }
814    }
815
    /// Build the spin-wait `barrier(barrier_num, total_barriers, thread_count)`
    /// function used to synchronize worker threads between tensor kernels.
    ///
    /// Each arriving thread atomically increments the shared thread counter,
    /// then spins until the counter reaches `barrier_num * thread_count`,
    /// i.e. until every thread has arrived at barrier number `barrier_num`.
    /// A call with `barrier_num == total_barriers` returns immediately
    /// (see the entry-block branch below).
    ///
    /// The verified function is registered in `self.functions` under
    /// "barrier". Returns an error (after deleting the function) if LLVM
    /// verification fails.
    fn compile_barrier(&mut self) -> Result<FunctionValue<'ctx>> {
        self.clear();
        let void_type = self.context.void_type();
        // Signature: (barrier_num, total_barriers, thread_count) -> void.
        let fn_type = void_type.fn_type(
            &[
                self.int_type.into(),
                self.int_type.into(),
                self.int_type.into(),
            ],
            false,
        );
        let function = self.module.add_function("barrier", fn_type, None);
        // Mark noinline so the optimizer cannot inline and then rearrange
        // the synchronization point.
        let nolinline_kind_id = Attribute::get_named_enum_kind_id("noinline");
        let noinline = self.context.create_enum_attribute(nolinline_kind_id, 0);
        function.add_attribute(AttributeLoc::Function, noinline);

        // Control-flow skeleton:
        //   entry -> (barrier_done | increment) ; increment -> wait_loop
        //   wait_loop -> (barrier_done | wait_loop)
        let entry_block = self.context.append_basic_block(function, "entry");
        let increment_block = self.context.append_basic_block(function, "increment");
        let wait_loop_block = self.context.append_basic_block(function, "wait_loop");
        let barrier_done_block = self.context.append_basic_block(function, "barrier_done");

        self.fn_value_opt = Some(function);
        self.builder.position_at_end(entry_block);

        let thread_counter = self.globals.thread_counter.unwrap().as_pointer_value();
        let barrier_num = function.get_nth_param(0).unwrap().into_int_value();
        let total_barriers = function.get_nth_param(1).unwrap().into_int_value();
        let thread_count = function.get_nth_param(2).unwrap().into_int_value();

        let nbarrier_equals_total_barriers = self
            .builder
            .build_int_compare(
                IntPredicate::EQ,
                barrier_num,
                total_barriers,
                "nbarrier_equals_total_barriers",
            )
            .unwrap();
        // branch to barrier_done if nbarrier == total_barriers
        self.builder.build_conditional_branch(
            nbarrier_equals_total_barriers,
            barrier_done_block,
            increment_block,
        )?;
        self.builder.position_at_end(increment_block);

        // Wait threshold: all threads have arrived once the counter reaches
        // barrier_num * thread_count (the counter accumulates across barriers).
        let barrier_num_times_thread_count = self
            .builder
            .build_int_mul(barrier_num, thread_count, "barrier_num_times_thread_count")
            .unwrap();

        // Atomically increment the barrier counter
        // NOTE(review): the counter is accessed as i32 here while the
        // comparison operand above is built from `self.int_type` — this
        // assumes int_type is i32; TODO confirm.
        let i32_type = self.context.i32_type();
        let one = i32_type.const_int(1, false);
        self.builder.build_atomicrmw(
            AtomicRMWBinOp::Add,
            thread_counter,
            one,
            AtomicOrdering::Monotonic,
        )?;

        // wait_loop:
        self.builder.build_unconditional_branch(wait_loop_block)?;
        self.builder.position_at_end(wait_loop_block);

        let current_value = self
            .builder
            .build_load(i32_type, thread_counter, "current_value")?
            .into_int_value();

        // Make the spin-load atomic (monotonic) so concurrent increments by
        // other threads are observed.
        current_value
            .as_instruction_value()
            .unwrap()
            .set_atomic_ordering(AtomicOrdering::Monotonic)
            .map_err(|e| anyhow!("Error setting atomic ordering: {:?}", e))?;

        let all_threads_done = self.builder.build_int_compare(
            IntPredicate::UGE,
            current_value,
            barrier_num_times_thread_count,
            "all_threads_done",
        )?;

        // Keep spinning until the threshold is reached.
        self.builder.build_conditional_branch(
            all_threads_done,
            barrier_done_block,
            wait_loop_block,
        )?;
        self.builder.position_at_end(barrier_done_block);

        self.builder.build_return(None)?;

        if function.verify(true) {
            self.functions.insert("barrier".to_owned(), function);
            Ok(function)
        } else {
            function.print_to_stderr();
            unsafe {
                function.delete();
            }
            Err(anyhow!("Invalid generated function."))
        }
    }
919
    /// Build `barrier_grad(barrier_num, total_barriers, thread_count)`, the
    /// barrier used during the reverse (gradient) pass.
    ///
    /// Like `barrier`, each thread atomically increments the shared counter,
    /// but the wait threshold is `(2*total_barriers - barrier_num) *
    /// thread_count`: the counter keeps accumulating past the forward pass's
    /// `total_barriers * thread_count`, and the reverse pass visits barriers
    /// in descending `barrier_num` order, so this expression continues the
    /// monotonically increasing sequence of thresholds.
    ///
    /// The verified function is registered in `self.functions` under
    /// "barrier_grad". Returns an error (after deleting the function) if LLVM
    /// verification fails.
    fn compile_barrier_grad(&mut self) -> Result<FunctionValue<'ctx>> {
        self.clear();
        let void_type = self.context.void_type();
        // Signature: (barrier_num, total_barriers, thread_count) -> void.
        let fn_type = void_type.fn_type(
            &[
                self.int_type.into(),
                self.int_type.into(),
                self.int_type.into(),
            ],
            false,
        );
        let function = self.module.add_function("barrier_grad", fn_type, None);

        // Control-flow skeleton: entry -> wait_loop -> (barrier_done | wait_loop).
        // Unlike `barrier`, there is no early-exit block for the final barrier.
        let entry_block = self.context.append_basic_block(function, "entry");
        let wait_loop_block = self.context.append_basic_block(function, "wait_loop");
        let barrier_done_block = self.context.append_basic_block(function, "barrier_done");

        self.fn_value_opt = Some(function);
        self.builder.position_at_end(entry_block);

        let thread_counter = self.globals.thread_counter.unwrap().as_pointer_value();
        let barrier_num = function.get_nth_param(0).unwrap().into_int_value();
        let total_barriers = function.get_nth_param(1).unwrap().into_int_value();
        let thread_count = function.get_nth_param(2).unwrap().into_int_value();

        // Threshold = (2 * total_barriers - barrier_num) * thread_count,
        // computed in three explicit steps below.
        let twice_total_barriers = self
            .builder
            .build_int_mul(
                total_barriers,
                self.int_type.const_int(2, false),
                "twice_total_barriers",
            )
            .unwrap();
        let twice_total_barriers_minus_barrier_num = self
            .builder
            .build_int_sub(
                twice_total_barriers,
                barrier_num,
                "twice_total_barriers_minus_barrier_num",
            )
            .unwrap();
        let twice_total_barriers_minus_barrier_num_times_thread_count = self
            .builder
            .build_int_mul(
                twice_total_barriers_minus_barrier_num,
                thread_count,
                "twice_total_barriers_minus_barrier_num_times_thread_count",
            )
            .unwrap();

        // Atomically increment the barrier counter
        // NOTE(review): counter accessed as i32 while the threshold above is
        // built from `self.int_type` — assumes int_type is i32; TODO confirm.
        let i32_type = self.context.i32_type();
        let one = i32_type.const_int(1, false);
        self.builder.build_atomicrmw(
            AtomicRMWBinOp::Add,
            thread_counter,
            one,
            AtomicOrdering::Monotonic,
        )?;

        // wait_loop:
        self.builder.build_unconditional_branch(wait_loop_block)?;
        self.builder.position_at_end(wait_loop_block);

        let current_value = self
            .builder
            .build_load(i32_type, thread_counter, "current_value")?
            .into_int_value();
        // Make the spin-load atomic (monotonic) so other threads' increments
        // are observed.
        current_value
            .as_instruction_value()
            .unwrap()
            .set_atomic_ordering(AtomicOrdering::Monotonic)
            .map_err(|e| anyhow!("Error setting atomic ordering: {:?}", e))?;

        let all_threads_done = self.builder.build_int_compare(
            IntPredicate::UGE,
            current_value,
            twice_total_barriers_minus_barrier_num_times_thread_count,
            "all_threads_done",
        )?;

        // Keep spinning until the threshold is reached.
        self.builder.build_conditional_branch(
            all_threads_done,
            barrier_done_block,
            wait_loop_block,
        )?;
        self.builder.position_at_end(barrier_done_block);

        self.builder.build_return(None)?;

        if function.verify(true) {
            self.functions.insert("barrier_grad".to_owned(), function);
            Ok(function)
        } else {
            function.print_to_stderr();
            unsafe {
                function.delete();
            }
            Err(anyhow!("Invalid generated function."))
        }
    }
1021
1022    fn jit_compile_call_barrier(&mut self, nbarrier: u64, total_barriers: u64) {
1023        if !self.threaded {
1024            return;
1025        }
1026        let thread_dim = self.get_param("thread_dim");
1027        let thread_dim = self
1028            .builder
1029            .build_load(self.int_type, *thread_dim, "thread_dim")
1030            .unwrap()
1031            .into_int_value();
1032        let nbarrier = self.int_type.const_int(nbarrier + 1, false);
1033        let total_barriers = self.int_type.const_int(total_barriers, false);
1034        let barrier = self.get_function("barrier").unwrap();
1035        self.builder
1036            .build_call(
1037                barrier,
1038                &[
1039                    BasicMetadataValueEnum::IntValue(nbarrier),
1040                    BasicMetadataValueEnum::IntValue(total_barriers),
1041                    BasicMetadataValueEnum::IntValue(thread_dim),
1042                ],
1043                "barrier",
1044            )
1045            .unwrap();
1046    }
1047
1048    fn jit_threading_limits(
1049        &mut self,
1050        size: IntValue<'ctx>,
1051    ) -> Result<(
1052        IntValue<'ctx>,
1053        IntValue<'ctx>,
1054        BasicBlock<'ctx>,
1055        BasicBlock<'ctx>,
1056    )> {
1057        let one = self.int_type.const_int(1, false);
1058        let thread_id = self.get_param("thread_id");
1059        let thread_id = self
1060            .builder
1061            .build_load(self.int_type, *thread_id, "thread_id")
1062            .unwrap()
1063            .into_int_value();
1064        let thread_dim = self.get_param("thread_dim");
1065        let thread_dim = self
1066            .builder
1067            .build_load(self.int_type, *thread_dim, "thread_dim")
1068            .unwrap()
1069            .into_int_value();
1070
1071        // start index is i * size / thread_dim
1072        let i_times_size = self
1073            .builder
1074            .build_int_mul(thread_id, size, "i_times_size")?;
1075        let start = self
1076            .builder
1077            .build_int_unsigned_div(i_times_size, thread_dim, "start")?;
1078
1079        // the ending index for thread i is (i+1) * size / thread_dim
1080        let i_plus_one = self.builder.build_int_add(thread_id, one, "i_plus_one")?;
1081        let i_plus_one_times_size =
1082            self.builder
1083                .build_int_mul(i_plus_one, size, "i_plus_one_times_size")?;
1084        let end = self
1085            .builder
1086            .build_int_unsigned_div(i_plus_one_times_size, thread_dim, "end")?;
1087
1088        let test_done = self.builder.get_insert_block().unwrap();
1089        let next_block = self
1090            .context
1091            .append_basic_block(self.fn_value_opt.unwrap(), "threading_block");
1092        self.builder.position_at_end(next_block);
1093
1094        Ok((start, end, test_done, next_block))
1095    }
1096
1097    fn jit_end_threading(
1098        &mut self,
1099        start: IntValue<'ctx>,
1100        end: IntValue<'ctx>,
1101        test_done: BasicBlock<'ctx>,
1102        next: BasicBlock<'ctx>,
1103    ) -> Result<()> {
1104        let exit = self
1105            .context
1106            .append_basic_block(self.fn_value_opt.unwrap(), "exit");
1107        self.builder.build_unconditional_branch(exit)?;
1108        self.builder.position_at_end(test_done);
1109        // done if start == end
1110        let done = self
1111            .builder
1112            .build_int_compare(IntPredicate::EQ, start, end, "done")?;
1113        self.builder.build_conditional_branch(done, exit, next)?;
1114        self.builder.position_at_end(exit);
1115        Ok(())
1116    }
1117
1118    pub fn write_bitcode_to_path(&self, path: &std::path::Path) {
1119        self.module.write_bitcode_to_path(path);
1120    }
1121
1122    fn insert_constants(&mut self, model: &DiscreteModel) {
1123        if let Some(constants) = self.globals.constants.as_ref() {
1124            self.insert_param("constants", constants.as_pointer_value());
1125            for tensor in model.constant_defns() {
1126                self.insert_tensor(tensor, true);
1127            }
1128        }
1129    }
1130
    /// Register pointers for every tensor stored in the runtime `data`
    /// buffer: model inputs plus the input-, time- and state-dependent
    /// definitions. Constants are registered first (into the separate
    /// `constants` buffer) so all tensor names resolve.
    fn insert_data(&mut self, model: &DiscreteModel) {
        self.insert_constants(model);

        // All of the following live in the mutable `data` buffer
        // (is_constant = false).
        for tensor in model.inputs() {
            self.insert_tensor(tensor, false);
        }
        for tensor in model.input_dep_defns() {
            self.insert_tensor(tensor, false);
        }
        for tensor in model.time_dep_defns() {
            self.insert_tensor(tensor, false);
        }
        for tensor in model.state_dep_defns() {
            self.insert_tensor(tensor, false);
        }
    }
1147
    /// Return an opaque pointer type in the default address space. The
    /// element type parameter is intentionally unused: pointers are opaque
    /// (untyped) here, so all pointee types map to the same LLVM type.
    fn pointer_type(context: &'ctx Context, _ty: BasicTypeEnum<'ctx>) -> PointerType<'ctx> {
        context.ptr_type(AddressSpace::default())
    }
1151
    /// Return an opaque pointer type for a function pointer. Like
    /// `pointer_type`, the function type parameter is unused because
    /// pointers are opaque here.
    fn fn_pointer_type(context: &'ctx Context, _ty: FunctionType<'ctx>) -> PointerType<'ctx> {
        context.ptr_type(AddressSpace::default())
    }
1155
1156    fn insert_indices(&mut self) {
1157        if let Some(indices) = self.globals.indices.as_ref() {
1158            let i32_type = self.context.i32_type();
1159            let zero = i32_type.const_int(0, false);
1160            let ptr = unsafe {
1161                indices
1162                    .as_pointer_value()
1163                    .const_in_bounds_gep(i32_type, &[zero])
1164            };
1165            self.variables.insert("indices".to_owned(), ptr);
1166        }
1167    }
1168
    /// Register a named pointer (e.g. a spilled function argument) in the
    /// variables map so later codegen can look it up with `get_param`.
    fn insert_param(&mut self, name: &str, value: PointerValue<'ctx>) {
        self.variables.insert(name.to_owned(), value);
    }
1172
1173    fn build_gep<T: BasicType<'ctx>>(
1174        &self,
1175        ty: T,
1176        ptr: PointerValue<'ctx>,
1177        ordered_indexes: &[IntValue<'ctx>],
1178        name: &str,
1179    ) -> Result<PointerValue<'ctx>> {
1180        unsafe {
1181            self.builder
1182                .build_gep(ty, ptr, ordered_indexes, name)
1183                .map_err(|e| e.into())
1184        }
1185    }
1186
1187    fn build_load<T: BasicType<'ctx>>(
1188        &self,
1189        ty: T,
1190        ptr: PointerValue<'ctx>,
1191        name: &str,
1192    ) -> Result<BasicValueEnum<'ctx>> {
1193        self.builder.build_load(ty, ptr, name).map_err(|e| e.into())
1194    }
1195
1196    fn get_ptr_to_index<T: BasicType<'ctx>>(
1197        builder: &Builder<'ctx>,
1198        ty: T,
1199        ptr: &PointerValue<'ctx>,
1200        index: IntValue<'ctx>,
1201        name: &str,
1202    ) -> PointerValue<'ctx> {
1203        unsafe {
1204            builder
1205                .build_in_bounds_gep(ty, *ptr, &[index], name)
1206                .unwrap()
1207        }
1208    }
1209
1210    fn insert_state(&mut self, u: &Tensor) {
1211        let mut data_index = 0;
1212        for blk in u.elmts() {
1213            if let Some(name) = blk.name() {
1214                let ptr = self.variables.get("u").unwrap();
1215                let i = self
1216                    .context
1217                    .i32_type()
1218                    .const_int(data_index.try_into().unwrap(), false);
1219                let alloca = Self::get_ptr_to_index(
1220                    &self.create_entry_block_builder(),
1221                    self.real_type,
1222                    ptr,
1223                    i,
1224                    blk.name().unwrap(),
1225                );
1226                self.variables.insert(name.to_owned(), alloca);
1227            }
1228            data_index += blk.nnz();
1229        }
1230    }
1231    fn insert_dot_state(&mut self, dudt: &Tensor) {
1232        let mut data_index = 0;
1233        for blk in dudt.elmts() {
1234            if let Some(name) = blk.name() {
1235                let ptr = self.variables.get("dudt").unwrap();
1236                let i = self
1237                    .context
1238                    .i32_type()
1239                    .const_int(data_index.try_into().unwrap(), false);
1240                let alloca = Self::get_ptr_to_index(
1241                    &self.create_entry_block_builder(),
1242                    self.real_type,
1243                    ptr,
1244                    i,
1245                    blk.name().unwrap(),
1246                );
1247                self.variables.insert(name.to_owned(), alloca);
1248            }
1249            data_index += blk.nnz();
1250        }
1251    }
1252    fn insert_tensor(&mut self, tensor: &Tensor, is_constant: bool) {
1253        let var_name = if is_constant { "constants" } else { "data" };
1254        let ptr = *self.variables.get(var_name).unwrap();
1255        let mut data_index = self.layout.get_data_index(tensor.name()).unwrap();
1256        let i = self
1257            .context
1258            .i32_type()
1259            .const_int(data_index.try_into().unwrap(), false);
1260        let alloca = Self::get_ptr_to_index(
1261            &self.create_entry_block_builder(),
1262            self.real_type,
1263            &ptr,
1264            i,
1265            tensor.name(),
1266        );
1267        self.variables.insert(tensor.name().to_owned(), alloca);
1268
1269        //insert any named blocks
1270        for blk in tensor.elmts() {
1271            if let Some(name) = blk.name() {
1272                let i = self
1273                    .context
1274                    .i32_type()
1275                    .const_int(data_index.try_into().unwrap(), false);
1276                let alloca = Self::get_ptr_to_index(
1277                    &self.create_entry_block_builder(),
1278                    self.real_type,
1279                    &ptr,
1280                    i,
1281                    name,
1282                );
1283                self.variables.insert(name.to_owned(), alloca);
1284            }
1285            // named blocks only supported for rank <= 1, so we can just add the nnz to get the next data index
1286            data_index += blk.nnz();
1287        }
1288    }
    /// Look up a named parameter's stack-slot pointer.
    /// Panics if `name` was never registered via `insert_param`.
    fn get_param(&self, name: &str) -> &PointerValue<'ctx> {
        self.variables.get(name).unwrap()
    }
1292
    /// Look up the registered pointer for a tensor by its name.
    /// Panics if the tensor was never registered via `insert_tensor`.
    fn get_var(&self, tensor: &Tensor) -> &PointerValue<'ctx> {
        self.variables.get(tensor.name()).unwrap()
    }
1296
1297    fn get_function(&mut self, name: &str) -> Option<FunctionValue<'ctx>> {
1298        match self.functions.get(name) {
1299            Some(&func) => Some(func),
1300            None => {
1301                let function = match name {
1302                    // support some llvm intrinsics
1303                    "sin" | "cos" | "tan" | "exp" | "log" | "log10" | "sqrt" | "abs"
1304                    | "copysign" | "pow" | "min" | "max" => {
1305                        let arg_len = 1;
1306                        let intrinsic_name = match name {
1307                            "min" => "minnum",
1308                            "max" => "maxnum",
1309                            "abs" => "fabs",
1310                            _ => name,
1311                        };
1312                        let llvm_name = format!("llvm.{}.{}", intrinsic_name, self.real_type_str);
1313                        let intrinsic = Intrinsic::find(&llvm_name).unwrap();
1314                        let ret_type = self.real_type;
1315
1316                        let args_types = std::iter::repeat_n(ret_type, arg_len)
1317                            .map(|f| f.into())
1318                            .collect::<Vec<BasicTypeEnum>>();
1319                        // if we get an intrinsic, we don't need to add to the list of functions and can return early
1320                        return intrinsic.get_declaration(&self.module, args_types.as_slice());
1321                    }
1322                    // some custom functions
1323                    "sigmoid" => {
1324                        let arg_len = 1;
1325                        let ret_type = self.real_type;
1326
1327                        let args_types = std::iter::repeat_n(ret_type, arg_len)
1328                            .map(|f| f.into())
1329                            .collect::<Vec<BasicMetadataTypeEnum>>();
1330
1331                        let fn_type = ret_type.fn_type(args_types.as_slice(), false);
1332                        let fn_val = self.module.add_function(name, fn_type, None);
1333
1334                        for arg in fn_val.get_param_iter() {
1335                            arg.into_float_value().set_name("x");
1336                        }
1337
1338                        let current_block = self.builder.get_insert_block().unwrap();
1339                        let basic_block = self.context.append_basic_block(fn_val, "entry");
1340                        self.builder.position_at_end(basic_block);
1341                        let x = fn_val.get_nth_param(0)?.into_float_value();
1342                        let one = self.real_type.const_float(1.0);
1343                        let negx = self.builder.build_float_neg(x, name).ok()?;
1344                        let exp = self.get_function("exp").unwrap();
1345                        let exp_negx = self
1346                            .builder
1347                            .build_call(exp, &[BasicMetadataValueEnum::FloatValue(negx)], name)
1348                            .ok()?;
1349                        let one_plus_exp_negx = self
1350                            .builder
1351                            .build_float_add(
1352                                exp_negx
1353                                    .try_as_basic_value()
1354                                    .left()
1355                                    .unwrap()
1356                                    .into_float_value(),
1357                                one,
1358                                name,
1359                            )
1360                            .ok()?;
1361                        let sigmoid = self
1362                            .builder
1363                            .build_float_div(one, one_plus_exp_negx, name)
1364                            .ok()?;
1365                        self.builder.build_return(Some(&sigmoid)).ok();
1366                        self.builder.position_at_end(current_block);
1367                        Some(fn_val)
1368                    }
1369                    "arcsinh" | "arccosh" => {
1370                        let arg_len = 1;
1371                        let ret_type = self.real_type;
1372
1373                        let args_types = std::iter::repeat_n(ret_type, arg_len)
1374                            .map(|f| f.into())
1375                            .collect::<Vec<BasicMetadataTypeEnum>>();
1376
1377                        let fn_type = ret_type.fn_type(args_types.as_slice(), false);
1378                        let fn_val = self.module.add_function(name, fn_type, None);
1379
1380                        for arg in fn_val.get_param_iter() {
1381                            arg.into_float_value().set_name("x");
1382                        }
1383
1384                        let current_block = self.builder.get_insert_block().unwrap();
1385                        let basic_block = self.context.append_basic_block(fn_val, "entry");
1386                        self.builder.position_at_end(basic_block);
1387                        let x = fn_val.get_nth_param(0)?.into_float_value();
1388                        let one = match name {
1389                            "arccosh" => self.real_type.const_float(-1.0),
1390                            "arcsinh" => self.real_type.const_float(1.0),
1391                            _ => panic!("unknown function"),
1392                        };
1393                        let x_squared = self.builder.build_float_mul(x, x, name).ok()?;
1394                        let one_plus_x_squared =
1395                            self.builder.build_float_add(x_squared, one, name).ok()?;
1396                        let sqrt = self.get_function("sqrt").unwrap();
1397                        let sqrt_one_plus_x_squared = self
1398                            .builder
1399                            .build_call(
1400                                sqrt,
1401                                &[BasicMetadataValueEnum::FloatValue(one_plus_x_squared)],
1402                                name,
1403                            )
1404                            .unwrap()
1405                            .try_as_basic_value()
1406                            .left()
1407                            .unwrap()
1408                            .into_float_value();
1409                        let x_plus_sqrt_one_plus_x_squared = self
1410                            .builder
1411                            .build_float_add(x, sqrt_one_plus_x_squared, name)
1412                            .ok()?;
1413                        let ln = self.get_function("log").unwrap();
1414                        let result = self
1415                            .builder
1416                            .build_call(
1417                                ln,
1418                                &[BasicMetadataValueEnum::FloatValue(
1419                                    x_plus_sqrt_one_plus_x_squared,
1420                                )],
1421                                name,
1422                            )
1423                            .unwrap()
1424                            .try_as_basic_value()
1425                            .left()
1426                            .unwrap()
1427                            .into_float_value();
1428                        self.builder.build_return(Some(&result)).ok();
1429                        self.builder.position_at_end(current_block);
1430                        Some(fn_val)
1431                    }
1432                    "heaviside" => {
1433                        let arg_len = 1;
1434                        let ret_type = self.real_type;
1435
1436                        let args_types = std::iter::repeat_n(ret_type, arg_len)
1437                            .map(|f| f.into())
1438                            .collect::<Vec<BasicMetadataTypeEnum>>();
1439
1440                        let fn_type = ret_type.fn_type(args_types.as_slice(), false);
1441                        let fn_val = self.module.add_function(name, fn_type, None);
1442
1443                        for arg in fn_val.get_param_iter() {
1444                            arg.into_float_value().set_name("x");
1445                        }
1446
1447                        let current_block = self.builder.get_insert_block().unwrap();
1448                        let basic_block = self.context.append_basic_block(fn_val, "entry");
1449                        self.builder.position_at_end(basic_block);
1450                        let x = fn_val.get_nth_param(0)?.into_float_value();
1451                        let zero = self.real_type.const_float(0.0);
1452                        let one = self.real_type.const_float(1.0);
1453                        let result = self
1454                            .builder
1455                            .build_select(
1456                                self.builder
1457                                    .build_float_compare(FloatPredicate::OGE, x, zero, "x >= 0")
1458                                    .unwrap(),
1459                                one,
1460                                zero,
1461                                name,
1462                            )
1463                            .ok()?;
1464                        self.builder.build_return(Some(&result)).ok();
1465                        self.builder.position_at_end(current_block);
1466                        Some(fn_val)
1467                    }
1468                    "tanh" | "sinh" | "cosh" => {
1469                        let arg_len = 1;
1470                        let ret_type = self.real_type;
1471
1472                        let args_types = std::iter::repeat_n(ret_type, arg_len)
1473                            .map(|f| f.into())
1474                            .collect::<Vec<BasicMetadataTypeEnum>>();
1475
1476                        let fn_type = ret_type.fn_type(args_types.as_slice(), false);
1477                        let fn_val = self.module.add_function(name, fn_type, None);
1478
1479                        for arg in fn_val.get_param_iter() {
1480                            arg.into_float_value().set_name("x");
1481                        }
1482
1483                        let current_block = self.builder.get_insert_block().unwrap();
1484                        let basic_block = self.context.append_basic_block(fn_val, "entry");
1485                        self.builder.position_at_end(basic_block);
1486                        let x = fn_val.get_nth_param(0)?.into_float_value();
1487                        let negx = self.builder.build_float_neg(x, name).ok()?;
1488                        let exp = self.get_function("exp").unwrap();
1489                        let exp_negx = self
1490                            .builder
1491                            .build_call(exp, &[BasicMetadataValueEnum::FloatValue(negx)], name)
1492                            .ok()?;
1493                        let expx = self
1494                            .builder
1495                            .build_call(exp, &[BasicMetadataValueEnum::FloatValue(x)], name)
1496                            .ok()?;
1497                        let expx_minus_exp_negx = self
1498                            .builder
1499                            .build_float_sub(
1500                                expx.try_as_basic_value().left().unwrap().into_float_value(),
1501                                exp_negx
1502                                    .try_as_basic_value()
1503                                    .left()
1504                                    .unwrap()
1505                                    .into_float_value(),
1506                                name,
1507                            )
1508                            .ok()?;
1509                        let expx_plus_exp_negx = self
1510                            .builder
1511                            .build_float_add(
1512                                expx.try_as_basic_value().left().unwrap().into_float_value(),
1513                                exp_negx
1514                                    .try_as_basic_value()
1515                                    .left()
1516                                    .unwrap()
1517                                    .into_float_value(),
1518                                name,
1519                            )
1520                            .ok()?;
1521                        let result = match name {
1522                            "tanh" => self
1523                                .builder
1524                                .build_float_div(expx_minus_exp_negx, expx_plus_exp_negx, name)
1525                                .ok()?,
1526                            "sinh" => self
1527                                .builder
1528                                .build_float_div(
1529                                    expx_minus_exp_negx,
1530                                    self.real_type.const_float(2.0),
1531                                    name,
1532                                )
1533                                .ok()?,
1534                            "cosh" => self
1535                                .builder
1536                                .build_float_div(
1537                                    expx_plus_exp_negx,
1538                                    self.real_type.const_float(2.0),
1539                                    name,
1540                                )
1541                                .ok()?,
1542                            _ => panic!("unknown function"),
1543                        };
1544                        self.builder.build_return(Some(&result)).ok();
1545                        self.builder.position_at_end(current_block);
1546                        Some(fn_val)
1547                    }
1548                    _ => None,
1549                }?;
1550                self.functions.insert(name.to_owned(), function);
1551                Some(function)
1552            }
1553        }
1554    }
1555    /// Returns the `FunctionValue` representing the function being compiled.
1556    #[inline]
1557    fn fn_value(&self) -> FunctionValue<'ctx> {
1558        self.fn_value_opt.unwrap()
1559    }
1560
1561    #[inline]
1562    fn tensor_ptr(&self) -> PointerValue<'ctx> {
1563        self.tensor_ptr_opt.unwrap()
1564    }
1565
1566    /// Creates a new builder in the entry block of the function.
1567    fn create_entry_block_builder(&self) -> Builder<'ctx> {
1568        let builder = self.context.create_builder();
1569        let entry = self.fn_value().get_first_basic_block().unwrap();
1570        match entry.get_first_instruction() {
1571            Some(first_instr) => builder.position_before(&first_instr),
1572            None => builder.position_at_end(entry),
1573        }
1574        builder
1575    }
1576
    /// Compiles the evaluation of a rank-0 (scalar) tensor `a`.
    ///
    /// The scalar's single expression is evaluated and stored through
    /// `res_ptr_opt` if supplied, otherwise through a fresh alloca created in
    /// the function's entry block. Returns the pointer holding the result.
    ///
    /// In threaded mode only thread 0 evaluates and stores the scalar; all
    /// other threads branch straight to the exit block.
    fn jit_compile_scalar(
        &mut self,
        a: &Tensor,
        res_ptr_opt: Option<PointerValue<'ctx>>,
    ) -> Result<PointerValue<'ctx>> {
        let res_type = self.real_type;
        let res_ptr = match res_ptr_opt {
            Some(ptr) => ptr,
            None => self
                .create_entry_block_builder()
                .build_alloca(res_type, a.name())?,
        };
        let name = a.name();
        // a rank-0 tensor is assumed to have exactly one element block
        let elmt = a.elmts().first().unwrap();

        // if threaded then only the first thread will evaluate the scalar:
        // emit the evaluation into a separate "next" block and wire up the
        // thread-id guard afterwards, once the blocks exist
        let curr_block = self.builder.get_insert_block().unwrap();
        let mut next_block_opt = None;
        if self.threaded {
            let next_block = self.context.append_basic_block(self.fn_value(), "next");
            self.builder.position_at_end(next_block);
            next_block_opt = Some(next_block);
        }

        // evaluate the scalar expression (no loop indices) and store it
        let float_value = self.jit_compile_expr(name, elmt.expr(), &[], elmt, None)?;
        self.builder.build_store(res_ptr, float_value)?;

        // complete the threading block: branch "next" -> "exit", then go back
        // to the original block and emit `if thread_id == 0 { next } else { exit }`
        if self.threaded {
            let exit_block = self.context.append_basic_block(self.fn_value(), "exit");
            self.builder.build_unconditional_branch(exit_block)?;
            self.builder.position_at_end(curr_block);

            let thread_id = self.get_param("thread_id");
            let thread_id = self
                .builder
                .build_load(self.int_type, *thread_id, "thread_id")
                .unwrap()
                .into_int_value();
            let is_first_thread = self.builder.build_int_compare(
                IntPredicate::EQ,
                thread_id,
                self.int_type.const_zero(),
                "is_first_thread",
            )?;
            self.builder.build_conditional_branch(
                is_first_thread,
                next_block_opt.unwrap(),
                exit_block,
            )?;

            // continue code generation after the guard
            self.builder.position_at_end(exit_block);
        }

        Ok(res_ptr)
    }
1633
1634    fn jit_compile_tensor(
1635        &mut self,
1636        a: &Tensor,
1637        res_ptr_opt: Option<PointerValue<'ctx>>,
1638    ) -> Result<PointerValue<'ctx>> {
1639        // treat scalar as a special case
1640        if a.rank() == 0 {
1641            return self.jit_compile_scalar(a, res_ptr_opt);
1642        }
1643
1644        let res_type = self.real_type;
1645        let res_ptr = match res_ptr_opt {
1646            Some(ptr) => ptr,
1647            None => self
1648                .create_entry_block_builder()
1649                .build_alloca(res_type, a.name())?,
1650        };
1651
1652        // set up the tensor storage pointer and index into this data
1653        self.tensor_ptr_opt = Some(res_ptr);
1654
1655        for (i, blk) in a.elmts().iter().enumerate() {
1656            let default = format!("{}-{}", a.name(), i);
1657            let name = blk.name().unwrap_or(default.as_str());
1658            self.jit_compile_block(name, a, blk)?;
1659        }
1660        Ok(res_ptr)
1661    }
1662
1663    fn jit_compile_block(&mut self, name: &str, tensor: &Tensor, elmt: &TensorBlock) -> Result<()> {
1664        let translation = Translation::new(
1665            elmt.expr_layout(),
1666            elmt.layout(),
1667            elmt.start(),
1668            tensor.layout_ptr(),
1669        );
1670
1671        if elmt.expr_layout().is_dense() {
1672            self.jit_compile_dense_block(name, elmt, &translation)
1673        } else if elmt.expr_layout().is_diagonal() {
1674            self.jit_compile_diagonal_block(name, elmt, &translation)
1675        } else if elmt.expr_layout().is_sparse() {
1676            match translation.source {
1677                TranslationFrom::SparseContraction { .. } => {
1678                    self.jit_compile_sparse_contraction_block(name, elmt, &translation)
1679                }
1680                _ => self.jit_compile_sparse_block(name, elmt, &translation),
1681            }
1682        } else {
1683            Err(anyhow!(
1684                "unsupported block layout: {:?}",
1685                elmt.expr_layout()
1686            ))
1687        }
1688    }
1689
    // for dense blocks we can loop through the nested loops to calculate the index, then we compile the expression passing in this index
    /// Compiles a dense tensor block as a nest of `expr_rank` loops, one per
    /// expression dimension, each implemented with a phi-node induction
    /// variable. The expression is evaluated in the innermost loop body and
    /// either accumulated into a contraction sum (dense contraction) or
    /// broadcast-and-stored into the output. In threaded mode the outermost
    /// loop is split across threads via `jit_threading_limits`.
    fn jit_compile_dense_block(
        &mut self,
        name: &str,
        elmt: &TensorBlock,
        translation: &Translation,
    ) -> Result<()> {
        let int_type = self.int_type;

        let mut preblock = self.builder.get_insert_block().unwrap();
        let expr_rank = elmt.expr_layout().rank();
        // shape of the expression as LLVM int constants
        let expr_shape = elmt
            .expr_layout()
            .shape()
            .mapv(|n| int_type.const_int(n.try_into().unwrap(), false));
        let one = int_type.const_int(1, false);

        // row-major strides: stride[i] = product of shape[i+1..]
        let mut expr_strides = vec![1; expr_rank];
        if expr_rank > 0 {
            for i in (0..expr_rank - 1).rev() {
                expr_strides[i] = expr_strides[i + 1] * elmt.expr_layout().shape()[i + 1];
            }
        }
        let expr_strides = expr_strides
            .iter()
            .map(|&s| int_type.const_int(s.try_into().unwrap(), false))
            .collect::<Vec<IntValue>>();

        // setup indices, loop through the nested loops
        let mut indices = Vec::new();
        let mut blocks = Vec::new();

        // allocate the contract sum if needed (dense contraction: the last
        // `contract_by` dimensions are summed over)
        let (contract_sum, contract_by, contract_strides) =
            if let TranslationFrom::DenseContraction {
                contract_by,
                contract_len: _,
            } = translation.source
            {
                let contract_rank = expr_rank - contract_by;
                // NOTE(review): if contract_rank == 0 (contraction to a
                // scalar) the range below underflows — presumably that case
                // never reaches this path; confirm against callers.
                let mut contract_strides = vec![1; contract_rank];
                for i in (0..contract_rank - 1).rev() {
                    contract_strides[i] =
                        contract_strides[i + 1] * elmt.expr_layout().shape()[i + 1];
                }
                let contract_strides = contract_strides
                    .iter()
                    .map(|&s| int_type.const_int(s.try_into().unwrap(), false))
                    .collect::<Vec<IntValue>>();
                (
                    Some(self.builder.build_alloca(self.real_type, "contract_sum")?),
                    contract_by,
                    Some(contract_strides),
                )
            } else {
                (None, 0, None)
            };

        // we will thread the output loop, except if we are contracting to a scalar
        let (thread_start, thread_end, test_done, next) = if self.threaded {
            let (start, end, test_done, next) =
                self.jit_threading_limits(*expr_shape.get(0).unwrap_or(&one))?;
            preblock = next;
            (Some(start), Some(end), Some(test_done), Some(next))
        } else {
            (None, None, None, None)
        };

        // open one loop header block per expression dimension; each carries a
        // phi induction variable seeded from the preceding block
        for i in 0..expr_rank {
            let block = self.context.append_basic_block(self.fn_value(), name);
            self.builder.build_unconditional_branch(block)?;
            self.builder.position_at_end(block);

            // the outermost loop starts at this thread's slice start
            let start_index = if i == 0 && self.threaded {
                thread_start.unwrap()
            } else {
                self.int_type.const_zero()
            };

            let curr_index = self.builder.build_phi(int_type, format!["i{i}"].as_str())?;
            curr_index.add_incoming(&[(&start_index, preblock)]);

            // zero the accumulator on entry to the innermost non-contracted loop
            if i == expr_rank - contract_by - 1 && contract_sum.is_some() {
                self.builder
                    .build_store(contract_sum.unwrap(), self.real_type.const_zero())?;
            }

            indices.push(curr_index);
            blocks.push(block);
            preblock = block;
        }

        let indices_int: Vec<IntValue> = indices
            .iter()
            .map(|i| i.as_basic_value().into_int_value())
            .collect();

        // innermost loop body: evaluate the expression at the current indices
        let float_value =
            self.jit_compile_expr(name, elmt.expr(), indices_int.as_slice(), elmt, None)?;

        if contract_sum.is_some() {
            // contraction: accumulate into contract_sum instead of storing
            let contract_sum_value = self
                .build_load(self.real_type, contract_sum.unwrap(), "contract_sum")?
                .into_float_value();
            let new_contract_sum_value = self.builder.build_float_add(
                contract_sum_value,
                float_value,
                "new_contract_sum",
            )?;
            self.builder
                .build_store(contract_sum.unwrap(), new_contract_sum_value)?;
        } else {
            // flatten the multi-index into a linear expression index via the
            // row-major strides, then broadcast/store the value
            let expr_index = indices_int.iter().zip(expr_strides.iter()).fold(
                self.int_type.const_zero(),
                |acc, (i, s)| {
                    let tmp = self.builder.build_int_mul(*i, *s, "expr_index").unwrap();
                    self.builder.build_int_add(acc, tmp, "acc").unwrap()
                },
            );
            preblock = self.jit_compile_broadcast_and_store(
                name,
                elmt,
                float_value,
                expr_index,
                translation,
                preblock,
            )?;
        }

        // unwind the nested loops
        for i in (0..expr_rank).rev() {
            // increment index
            let next_index = self.builder.build_int_add(indices_int[i], one, name)?;
            indices[i].add_incoming(&[(&next_index, preblock)]);

            // when closing the innermost non-contracted loop, store the
            // finished accumulator at the linear index of the kept dimensions
            if i == expr_rank - contract_by - 1 && contract_sum.is_some() {
                let contract_sum_value = self
                    .build_load(self.real_type, contract_sum.unwrap(), "contract_sum")?
                    .into_float_value();
                let contract_strides = contract_strides.as_ref().unwrap();
                let elmt_index = indices_int
                    .iter()
                    .take(contract_strides.len())
                    .zip(contract_strides.iter())
                    .fold(self.int_type.const_zero(), |acc, (i, s)| {
                        let tmp = self.builder.build_int_mul(*i, *s, "elmt_index").unwrap();
                        self.builder.build_int_add(acc, tmp, "acc").unwrap()
                    });
                self.jit_compile_store(name, elmt, elmt_index, contract_sum_value, translation)?;
            }

            // the outermost loop stops at this thread's slice end
            let end_index = if i == 0 && self.threaded {
                thread_end.unwrap()
            } else {
                expr_shape[i]
            };

            // loop condition
            let loop_while =
                self.builder
                    .build_int_compare(IntPredicate::ULT, next_index, end_index, name)?;
            let block = self.context.append_basic_block(self.fn_value(), name);
            self.builder
                .build_conditional_branch(loop_while, blocks[i], block)?;
            self.builder.position_at_end(block);
            preblock = block;
        }

        if self.threaded {
            // emit the thread synchronisation / exit wiring
            self.jit_end_threading(
                thread_start.unwrap(),
                thread_end.unwrap(),
                test_done.unwrap(),
                next.unwrap(),
            )?;
        }
        Ok(())
    }
1868
    /// Compiles a sparse contraction block as two nested loops: an outer loop
    /// over the nnz entries of the *output* layout, and an inner loop over the
    /// contiguous range of expression entries that contract into each output
    /// entry. The per-output [start, end) ranges are read from the "indices"
    /// parameter at `translation_index` (stored as consecutive pairs), and the
    /// expression's multi-indices are read from the same buffer at
    /// `layout_index`. In threaded mode the outer loop is split across
    /// threads. Panics if `translation.source` is not a `SparseContraction`.
    fn jit_compile_sparse_contraction_block(
        &mut self,
        name: &str,
        elmt: &TensorBlock,
        translation: &Translation,
    ) -> Result<()> {
        match translation.source {
            TranslationFrom::SparseContraction { .. } => {}
            _ => {
                panic!("expected sparse contraction")
            }
        }
        let int_type = self.int_type;

        // offsets into the shared "indices" buffer for this block's layout
        // data and its contraction translation table
        let layout_index = self.layout.get_layout_index(elmt.expr_layout()).unwrap();
        let translation_index = self
            .layout
            .get_translation_index(elmt.expr_layout(), elmt.layout())
            .unwrap();
        let translation_index = translation_index + translation.get_from_index_in_data_layout();

        // outer loop bound: number of output (contracted) entries
        let final_contract_index =
            int_type.const_int(elmt.layout().nnz().try_into().unwrap(), false);
        let (thread_start, thread_end, test_done, next) = if self.threaded {
            let (start, end, test_done, next) = self.jit_threading_limits(final_contract_index)?;
            (Some(start), Some(end), Some(test_done), Some(next))
        } else {
            (None, None, None, None)
        };

        let preblock = self.builder.get_insert_block().unwrap();
        let contract_sum_ptr = self.builder.build_alloca(self.real_type, "contract_sum")?;

        // loop through each contraction
        let block = self.context.append_basic_block(self.fn_value(), name);
        self.builder.build_unconditional_branch(block)?;
        self.builder.position_at_end(block);

        // phi induction variable for the outer (output-entry) loop
        let contract_index = self.builder.build_phi(int_type, "i")?;
        let contract_start = if self.threaded {
            thread_start.unwrap()
        } else {
            int_type.const_zero()
        };
        contract_index.add_incoming(&[(&contract_start, preblock)]);

        // the translation table stores (start, end) pairs, so the pair for
        // output entry i lives at translation_index + 2*i
        let start_index = self.builder.build_int_add(
            int_type.const_int(translation_index.try_into().unwrap(), false),
            self.builder.build_int_mul(
                int_type.const_int(2, false),
                contract_index.as_basic_value().into_int_value(),
                name,
            )?,
            name,
        )?;
        let end_index =
            self.builder
                .build_int_add(start_index, int_type.const_int(1, false), name)?;
        let start_ptr = self.build_gep(
            self.int_type,
            *self.get_param("indices"),
            &[start_index],
            "start_index_ptr",
        )?;
        let start_contract = self
            .build_load(self.int_type, start_ptr, "start")?
            .into_int_value();
        let end_ptr = self.build_gep(
            self.int_type,
            *self.get_param("indices"),
            &[end_index],
            "end_index_ptr",
        )?;
        let end_contract = self
            .build_load(self.int_type, end_ptr, "end")?
            .into_int_value();

        // initialise the contract sum
        self.builder
            .build_store(contract_sum_ptr, self.real_type.const_float(0.0))?;

        // loop through each element in the contraction
        let contract_block = self
            .context
            .append_basic_block(self.fn_value(), format!("{name}_contract").as_str());
        self.builder.build_unconditional_branch(contract_block)?;
        self.builder.position_at_end(contract_block);

        // phi induction variable for the inner (expression-entry) loop
        let expr_index_phi = self.builder.build_phi(int_type, "j")?;
        expr_index_phi.add_incoming(&[(&start_contract, block)]);

        // loop body - load index from layout: entry j's rank-many indices are
        // stored contiguously at layout_index + j*rank
        let expr_index = expr_index_phi.as_basic_value().into_int_value();
        let elmt_index_mult_rank = self.builder.build_int_mul(
            expr_index,
            int_type.const_int(elmt.expr_layout().rank().try_into().unwrap(), false),
            name,
        )?;
        let indices_int = (0..elmt.expr_layout().rank())
            .map(|i| {
                let layout_index_plus_offset =
                    int_type.const_int((layout_index + i).try_into().unwrap(), false);
                let curr_index = self.builder.build_int_add(
                    elmt_index_mult_rank,
                    layout_index_plus_offset,
                    name,
                )?;
                let ptr = Self::get_ptr_to_index(
                    &self.builder,
                    self.int_type,
                    self.get_param("indices"),
                    curr_index,
                    name,
                );
                let index = self.build_load(self.int_type, ptr, name)?.into_int_value();
                Ok(index)
            })
            .collect::<Result<Vec<_>, anyhow::Error>>()?;

        // loop body - eval expression and increment sum
        let float_value = self.jit_compile_expr(
            name,
            elmt.expr(),
            indices_int.as_slice(),
            elmt,
            Some(expr_index),
        )?;
        let contract_sum_value = self
            .build_load(self.real_type, contract_sum_ptr, "contract_sum")?
            .into_float_value();
        let new_contract_sum_value =
            self.builder
                .build_float_add(contract_sum_value, float_value, "new_contract_sum")?;
        self.builder
            .build_store(contract_sum_ptr, new_contract_sum_value)?;

        // increment contract loop index
        let next_elmt_index =
            self.builder
                .build_int_add(expr_index, int_type.const_int(1, false), name)?;
        expr_index_phi.add_incoming(&[(&next_elmt_index, contract_block)]);

        // contract loop condition
        let loop_while = self.builder.build_int_compare(
            IntPredicate::ULT,
            next_elmt_index,
            end_contract,
            name,
        )?;
        let post_contract_block = self.context.append_basic_block(self.fn_value(), name);
        self.builder
            .build_conditional_branch(loop_while, contract_block, post_contract_block)?;
        self.builder.position_at_end(post_contract_block);

        // store the result
        self.jit_compile_store(
            name,
            elmt,
            contract_index.as_basic_value().into_int_value(),
            new_contract_sum_value,
            translation,
        )?;

        // increment outer loop index
        let next_contract_index = self.builder.build_int_add(
            contract_index.as_basic_value().into_int_value(),
            int_type.const_int(1, false),
            name,
        )?;
        contract_index.add_incoming(&[(&next_contract_index, post_contract_block)]);

        // outer loop condition (per-thread slice end when threaded)
        let loop_while = self.builder.build_int_compare(
            IntPredicate::ULT,
            next_contract_index,
            thread_end.unwrap_or(final_contract_index),
            name,
        )?;
        let post_block = self.context.append_basic_block(self.fn_value(), name);
        self.builder
            .build_conditional_branch(loop_while, block, post_block)?;
        self.builder.position_at_end(post_block);

        if self.threaded {
            // emit the thread synchronisation / exit wiring
            self.jit_end_threading(
                thread_start.unwrap(),
                thread_end.unwrap(),
                test_done.unwrap(),
                next.unwrap(),
            )?;
        }

        Ok(())
    }
2063
    // for sparse blocks we can loop through the non-zero elements and extract the index from the layout, then we compile the expression passing in this index
    // TODO: haven't implemented contractions yet
    /// Compiles a sparse (non-contracting) tensor block as a single loop over
    /// the nnz entries of the expression layout. Each entry's rank-many
    /// multi-indices are loaded from the shared "indices" buffer (entry i's
    /// indices start at `layout_index + i*rank`), the expression is evaluated
    /// at those indices, and the value is broadcast/stored. In threaded mode
    /// the loop is split across threads via `jit_threading_limits`.
    fn jit_compile_sparse_block(
        &mut self,
        name: &str,
        elmt: &TensorBlock,
        translation: &Translation,
    ) -> Result<()> {
        let int_type = self.int_type;

        // offset of this layout's index data within the "indices" buffer
        let layout_index = self.layout.get_layout_index(elmt.expr_layout()).unwrap();
        let start_index = int_type.const_int(0, false);
        let end_index = int_type.const_int(elmt.expr_layout().nnz().try_into().unwrap(), false);

        let (thread_start, thread_end, test_done, next) = if self.threaded {
            let (start, end, test_done, next) = self.jit_threading_limits(end_index)?;
            (Some(start), Some(end), Some(test_done), Some(next))
        } else {
            (None, None, None, None)
        };

        // loop through the non-zero elements
        let preblock = self.builder.get_insert_block().unwrap();
        let mut block = self.context.append_basic_block(self.fn_value(), name);
        self.builder.build_unconditional_branch(block)?;
        self.builder.position_at_end(block);

        // phi induction variable, seeded with the thread's slice start
        let curr_index = self.builder.build_phi(int_type, "i")?;
        curr_index.add_incoming(&[(&thread_start.unwrap_or(start_index), preblock)]);

        // loop body - load index from layout
        let elmt_index = curr_index.as_basic_value().into_int_value();
        let elmt_index_mult_rank = self.builder.build_int_mul(
            elmt_index,
            int_type.const_int(elmt.expr_layout().rank().try_into().unwrap(), false),
            name,
        )?;
        let indices_int = (0..elmt.expr_layout().rank())
            .map(|i| {
                let layout_index_plus_offset =
                    int_type.const_int((layout_index + i).try_into().unwrap(), false);
                let curr_index = self.builder.build_int_add(
                    elmt_index_mult_rank,
                    layout_index_plus_offset,
                    name,
                )?;
                let ptr = Self::get_ptr_to_index(
                    &self.builder,
                    self.int_type,
                    self.get_param("indices"),
                    curr_index,
                    name,
                );
                Ok(self.build_load(self.int_type, ptr, name)?.into_int_value())
            })
            .collect::<Result<Vec<_>, anyhow::Error>>()?;

        // loop body - eval expression
        let float_value = self.jit_compile_expr(
            name,
            elmt.expr(),
            indices_int.as_slice(),
            elmt,
            Some(elmt_index),
        )?;

        // loop body - store result (may open further blocks when broadcasting,
        // hence `block` is updated to the block we end up in)
        block = self.jit_compile_broadcast_and_store(
            name,
            elmt,
            float_value,
            elmt_index,
            translation,
            block,
        )?;

        // increment loop index
        let one = int_type.const_int(1, false);
        let next_index = self.builder.build_int_add(elmt_index, one, name)?;
        curr_index.add_incoming(&[(&next_index, block)]);

        // loop condition (per-thread slice end when threaded)
        let loop_while = self.builder.build_int_compare(
            IntPredicate::ULT,
            next_index,
            thread_end.unwrap_or(end_index),
            name,
        )?;
        let post_block = self.context.append_basic_block(self.fn_value(), name);
        self.builder
            .build_conditional_branch(loop_while, block, post_block)?;
        self.builder.position_at_end(post_block);

        if self.threaded {
            // emit the thread synchronisation / exit wiring
            self.jit_end_threading(
                thread_start.unwrap(),
                thread_end.unwrap(),
                test_done.unwrap(),
                next.unwrap(),
            )?;
        }

        Ok(())
    }
2167
    // for diagonal blocks we can loop through the diagonal elements and the index is just the same for each element, then we compile the expression passing in this index
    /// Compiles a diagonal tensor block as a single loop over the nnz
    /// diagonal entries. Unlike the sparse path, no index data is loaded:
    /// every dimension's index equals the loop counter (element [i, i, ...]).
    /// In threaded mode the loop is split across threads via
    /// `jit_threading_limits`.
    fn jit_compile_diagonal_block(
        &mut self,
        name: &str,
        elmt: &TensorBlock,
        translation: &Translation,
    ) -> Result<()> {
        let int_type = self.int_type;

        let start_index = int_type.const_int(0, false);
        let end_index = int_type.const_int(elmt.expr_layout().nnz().try_into().unwrap(), false);

        let (thread_start, thread_end, test_done, next) = if self.threaded {
            let (start, end, test_done, next) = self.jit_threading_limits(end_index)?;
            (Some(start), Some(end), Some(test_done), Some(next))
        } else {
            (None, None, None, None)
        };

        // loop through the non-zero elements
        let preblock = self.builder.get_insert_block().unwrap();
        let mut block = self.context.append_basic_block(self.fn_value(), name);
        self.builder.build_unconditional_branch(block)?;
        self.builder.position_at_end(block);

        // phi induction variable, seeded with the thread's slice start
        let curr_index = self.builder.build_phi(int_type, "i")?;
        curr_index.add_incoming(&[(&thread_start.unwrap_or(start_index), preblock)]);

        // loop body - index is just the same for each element
        let elmt_index = curr_index.as_basic_value().into_int_value();
        let indices_int: Vec<IntValue> =
            (0..elmt.expr_layout().rank()).map(|_| elmt_index).collect();

        // loop body - eval expression
        let float_value = self.jit_compile_expr(
            name,
            elmt.expr(),
            indices_int.as_slice(),
            elmt,
            Some(elmt_index),
        )?;

        // loop body - store result (may open further blocks when broadcasting,
        // hence `block` is updated to the block we end up in)
        block = self.jit_compile_broadcast_and_store(
            name,
            elmt,
            float_value,
            elmt_index,
            translation,
            block,
        )?;

        // increment loop index
        let one = int_type.const_int(1, false);
        let next_index = self.builder.build_int_add(elmt_index, one, name)?;
        curr_index.add_incoming(&[(&next_index, block)]);

        // loop condition (per-thread slice end when threaded)
        let loop_while = self.builder.build_int_compare(
            IntPredicate::ULT,
            next_index,
            thread_end.unwrap_or(end_index),
            name,
        )?;
        let post_block = self.context.append_basic_block(self.fn_value(), name);
        self.builder
            .build_conditional_branch(loop_while, block, post_block)?;
        self.builder.position_at_end(post_block);

        if self.threaded {
            // emit the thread synchronisation / exit wiring
            self.jit_end_threading(
                thread_start.unwrap(),
                thread_end.unwrap(),
                test_done.unwrap(),
                next.unwrap(),
            )?;
        }

        Ok(())
    }
2248
    /// Store the scalar `float_value` (the expression evaluated at
    /// `expr_index`) into the target tensor, expanding it according to
    /// `translation.source`.
    ///
    /// * `Broadcast` — emits an inner loop writing the value to
    ///   `broadcast_len` consecutive destinations starting at
    ///   `expr_index * broadcast_len`.
    /// * `ElementWise` / `DiagonalContraction` — a single store at
    ///   `expr_index`.
    ///
    /// Any other translation source is an error. Returns the basic block the
    /// builder is left positioned at, so the caller can wire up subsequent
    /// phi incoming edges.
    fn jit_compile_broadcast_and_store(
        &mut self,
        name: &str,
        elmt: &TensorBlock,
        float_value: FloatValue<'ctx>,
        expr_index: IntValue<'ctx>,
        translation: &Translation,
        pre_block: BasicBlock<'ctx>,
    ) -> Result<BasicBlock<'ctx>> {
        let int_type = self.int_type;
        let one = int_type.const_int(1, false);
        let zero = int_type.const_int(0, false);
        match translation.source {
            TranslationFrom::Broadcast {
                broadcast_by: _,
                broadcast_len,
            } => {
                let bcast_start_index = zero;
                let bcast_end_index = int_type.const_int(broadcast_len.try_into().unwrap(), false);

                // setup loop block; the phi gets its loop-edge incoming below
                let bcast_block = self.context.append_basic_block(self.fn_value(), name);
                self.builder.build_unconditional_branch(bcast_block)?;
                self.builder.position_at_end(bcast_block);
                let bcast_index = self.builder.build_phi(int_type, "broadcast_index")?;
                bcast_index.add_incoming(&[(&bcast_start_index, pre_block)]);

                // store value at expr_index * broadcast_len + bcast_index
                let store_index = self.builder.build_int_add(
                    self.builder
                        .build_int_mul(expr_index, bcast_end_index, "store_index")?,
                    bcast_index.as_basic_value().into_int_value(),
                    "bcast_store_index",
                )?;
                self.jit_compile_store(name, elmt, store_index, float_value, translation)?;

                // increment index and close the phi's back edge
                let bcast_next_index = self.builder.build_int_add(
                    bcast_index.as_basic_value().into_int_value(),
                    one,
                    name,
                )?;
                bcast_index.add_incoming(&[(&bcast_next_index, bcast_block)]);

                // loop condition: keep going while next index < broadcast_len
                let bcast_cond = self.builder.build_int_compare(
                    IntPredicate::ULT,
                    bcast_next_index,
                    bcast_end_index,
                    "broadcast_cond",
                )?;
                let post_bcast_block = self.context.append_basic_block(self.fn_value(), name);
                self.builder
                    .build_conditional_branch(bcast_cond, bcast_block, post_bcast_block)?;
                self.builder.position_at_end(post_bcast_block);

                // return the current block for later
                Ok(post_bcast_block)
            }
            TranslationFrom::ElementWise | TranslationFrom::DiagonalContraction { .. } => {
                self.jit_compile_store(name, elmt, expr_index, float_value, translation)?;
                Ok(pre_block)
            }
            _ => Err(anyhow!("Invalid translation")),
        }
    }
2315
    /// Write `float_value` into the result tensor at the element selected by
    /// `store_index` and the translation's target.
    ///
    /// * `Contiguous` target: destination index is `start + store_index`.
    /// * `Sparse` target: the destination index is loaded from the "indices"
    ///   parameter, offset by the translation index for this
    ///   expr-layout/target-layout pair within the data layout.
    fn jit_compile_store(
        &mut self,
        name: &str,
        elmt: &TensorBlock,
        store_index: IntValue<'ctx>,
        float_value: FloatValue<'ctx>,
        translation: &Translation,
    ) -> Result<()> {
        let int_type = self.int_type;
        let res_index = match &translation.target {
            TranslationTo::Contiguous { start, end: _ } => {
                let start_const = int_type.const_int((*start).try_into().unwrap(), false);
                self.builder.build_int_add(start_const, store_index, name)?
            }
            TranslationTo::Sparse { indices: _ } => {
                // load store index from layout
                let translate_index = self
                    .layout
                    .get_translation_index(elmt.expr_layout(), elmt.layout())
                    .unwrap();
                let translate_store_index =
                    translate_index + translation.get_to_index_in_data_layout();
                let translate_store_index =
                    int_type.const_int(translate_store_index.try_into().unwrap(), false);
                let elmt_index_strided = store_index;
                let curr_index =
                    self.builder
                        .build_int_add(elmt_index_strided, translate_store_index, name)?;
                let ptr = Self::get_ptr_to_index(
                    &self.builder,
                    self.int_type,
                    self.get_param("indices"),
                    curr_index,
                    name,
                );
                self.build_load(self.int_type, ptr, name)?.into_int_value()
            }
        };

        // store the value at the resolved index in the current output tensor
        let resi_ptr = Self::get_ptr_to_index(
            &self.builder,
            self.real_type,
            &self.tensor_ptr(),
            res_index,
            name,
        );
        self.builder.build_store(resi_ptr, float_value)?;
        Ok(())
    }
2365
2366    fn jit_compile_expr(
2367        &mut self,
2368        name: &str,
2369        expr: &Ast,
2370        index: &[IntValue<'ctx>],
2371        elmt: &TensorBlock,
2372        expr_index: Option<IntValue<'ctx>>,
2373    ) -> Result<FloatValue<'ctx>> {
2374        let name = elmt.name().unwrap_or(name);
2375        match &expr.kind {
2376            AstKind::Binop(binop) => {
2377                let lhs =
2378                    self.jit_compile_expr(name, binop.left.as_ref(), index, elmt, expr_index)?;
2379                let rhs =
2380                    self.jit_compile_expr(name, binop.right.as_ref(), index, elmt, expr_index)?;
2381                match binop.op {
2382                    '*' => Ok(self.builder.build_float_mul(lhs, rhs, name)?),
2383                    '/' => Ok(self.builder.build_float_div(lhs, rhs, name)?),
2384                    '-' => Ok(self.builder.build_float_sub(lhs, rhs, name)?),
2385                    '+' => Ok(self.builder.build_float_add(lhs, rhs, name)?),
2386                    unknown => Err(anyhow!("unknown binop op '{}'", unknown)),
2387                }
2388            }
2389            AstKind::Monop(monop) => {
2390                let child =
2391                    self.jit_compile_expr(name, monop.child.as_ref(), index, elmt, expr_index)?;
2392                match monop.op {
2393                    '-' => Ok(self.builder.build_float_neg(child, name)?),
2394                    unknown => Err(anyhow!("unknown monop op '{}'", unknown)),
2395                }
2396            }
2397            AstKind::Call(call) => match self.get_function(call.fn_name) {
2398                Some(function) => {
2399                    let mut args: Vec<BasicMetadataValueEnum> = Vec::new();
2400                    for arg in call.args.iter() {
2401                        let arg_val =
2402                            self.jit_compile_expr(name, arg.as_ref(), index, elmt, expr_index)?;
2403                        args.push(BasicMetadataValueEnum::FloatValue(arg_val));
2404                    }
2405                    let ret_value = self
2406                        .builder
2407                        .build_call(function, args.as_slice(), name)?
2408                        .try_as_basic_value()
2409                        .left()
2410                        .unwrap()
2411                        .into_float_value();
2412                    Ok(ret_value)
2413                }
2414                None => Err(anyhow!("unknown function call '{}'", call.fn_name)),
2415            },
2416            AstKind::CallArg(arg) => {
2417                self.jit_compile_expr(name, &arg.expression, index, elmt, expr_index)
2418            }
2419            AstKind::Number(value) => Ok(self.real_type.const_float(*value)),
2420            AstKind::Name(iname) => {
2421                let ptr = self.get_param(iname.name);
2422                let layout = self.layout.get_layout(iname.name).unwrap();
2423                let iname_elmt_index = if layout.is_dense() {
2424                    // permute indices based on the index chars of this tensor
2425                    let mut no_transform = true;
2426                    let mut iname_index = Vec::new();
2427                    for (i, c) in iname.indices.iter().enumerate() {
2428                        // find the position index of this index char in the tensor's index chars,
2429                        // if it's not found then it must be a contraction index so is at the end
2430                        let pi = elmt
2431                            .indices()
2432                            .iter()
2433                            .position(|x| x == c)
2434                            .unwrap_or(elmt.indices().len());
2435                        iname_index.push(index[pi]);
2436                        no_transform = no_transform && pi == i;
2437                    }
2438                    // calculate the element index using iname_index and the shape of the tensor
2439                    // TODO: can we optimise this by using expr_index, and also including elmt_index?
2440                    if !iname_index.is_empty() {
2441                        let mut iname_elmt_index = *iname_index.last().unwrap();
2442                        let mut stride = 1u64;
2443                        for i in (0..iname_index.len() - 1).rev() {
2444                            let iname_i = iname_index[i];
2445                            let shapei: u64 = layout.shape()[i + 1].try_into().unwrap();
2446                            stride *= shapei;
2447                            let stride_intval = self.context.i32_type().const_int(stride, false);
2448                            let stride_mul_i =
2449                                self.builder.build_int_mul(stride_intval, iname_i, name)?;
2450                            iname_elmt_index =
2451                                self.builder
2452                                    .build_int_add(iname_elmt_index, stride_mul_i, name)?;
2453                        }
2454                        Some(iname_elmt_index)
2455                    } else {
2456                        let zero = self.context.i32_type().const_int(0, false);
2457                        Some(zero)
2458                    }
2459                } else if layout.is_sparse() || layout.is_diagonal() {
2460                    // must have come from jit_compile_sparse_block, so we can just use the elmt_index
2461                    // must have come from jit_compile_diagonal_block, so we can just use the elmt_index
2462                    expr_index
2463                } else {
2464                    panic!("unexpected layout");
2465                };
2466                let value_ptr = match iname_elmt_index {
2467                    Some(index) => {
2468                        Self::get_ptr_to_index(&self.builder, self.real_type, ptr, index, name)
2469                    }
2470                    None => *ptr,
2471                };
2472                Ok(self
2473                    .build_load(self.real_type, value_ptr, name)?
2474                    .into_float_value())
2475            }
2476            AstKind::NamedGradient(name) => {
2477                let name_str = name.to_string();
2478                let ptr = self.get_param(name_str.as_str());
2479                Ok(self
2480                    .build_load(self.real_type, *ptr, name_str.as_str())?
2481                    .into_float_value())
2482            }
2483            AstKind::Index(_) => todo!(),
2484            AstKind::Slice(_) => todo!(),
2485            AstKind::Integer(_) => todo!(),
2486            _ => panic!("unexprected astkind"),
2487        }
2488    }
2489
2490    fn clear(&mut self) {
2491        self.variables.clear();
2492        //self.functions.clear();
2493        self.fn_value_opt = None;
2494        self.tensor_ptr_opt = None;
2495    }
2496
2497    fn function_arg_alloca(&mut self, name: &str, arg: BasicValueEnum<'ctx>) -> PointerValue<'ctx> {
2498        match arg {
2499            BasicValueEnum::PointerValue(v) => v,
2500            BasicValueEnum::FloatValue(v) => {
2501                let alloca = self
2502                    .create_entry_block_builder()
2503                    .build_alloca(arg.get_type(), name)
2504                    .unwrap();
2505                self.builder.build_store(alloca, v).unwrap();
2506                alloca
2507            }
2508            BasicValueEnum::IntValue(v) => {
2509                let alloca = self
2510                    .create_entry_block_builder()
2511                    .build_alloca(arg.get_type(), name)
2512                    .unwrap();
2513                self.builder.build_store(alloca, v).unwrap();
2514                alloca
2515            }
2516            _ => unreachable!(),
2517        }
2518    }
2519
2520    pub fn compile_set_u0<'m>(&mut self, model: &'m DiscreteModel) -> Result<FunctionValue<'ctx>> {
2521        self.clear();
2522        let void_type = self.context.void_type();
2523        let fn_type = void_type.fn_type(
2524            &[
2525                self.real_ptr_type.into(),
2526                self.real_ptr_type.into(),
2527                self.int_type.into(),
2528                self.int_type.into(),
2529            ],
2530            false,
2531        );
2532        let fn_arg_names = &["u0", "data", "thread_id", "thread_dim"];
2533        let function = self.module.add_function("set_u0", fn_type, None);
2534
2535        // add noalias
2536        let alias_id = Attribute::get_named_enum_kind_id("noalias");
2537        let noalign = self.context.create_enum_attribute(alias_id, 0);
2538        for i in &[0, 1] {
2539            function.add_attribute(AttributeLoc::Param(*i), noalign);
2540        }
2541
2542        let basic_block = self.context.append_basic_block(function, "entry");
2543        self.fn_value_opt = Some(function);
2544        self.builder.position_at_end(basic_block);
2545
2546        for (i, arg) in function.get_param_iter().enumerate() {
2547            let name = fn_arg_names[i];
2548            let alloca = self.function_arg_alloca(name, arg);
2549            self.insert_param(name, alloca);
2550        }
2551
2552        self.insert_data(model);
2553        self.insert_indices();
2554
2555        let mut nbarriers = 0;
2556        let total_barriers = (model.input_dep_defns().len() + 1) as u64;
2557        #[allow(clippy::explicit_counter_loop)]
2558        for a in model.input_dep_defns() {
2559            self.jit_compile_tensor(a, Some(*self.get_var(a)))?;
2560            self.jit_compile_call_barrier(nbarriers, total_barriers);
2561            nbarriers += 1;
2562        }
2563
2564        self.jit_compile_tensor(model.state(), Some(*self.get_param("u0")))?;
2565        self.jit_compile_call_barrier(nbarriers, total_barriers);
2566
2567        self.builder.build_return(None)?;
2568
2569        if function.verify(true) {
2570            Ok(function)
2571        } else {
2572            function.print_to_stderr();
2573            unsafe {
2574                function.delete();
2575            }
2576            Err(anyhow!("Invalid generated function."))
2577        }
2578    }
2579
2580    pub fn compile_calc_out<'m>(
2581        &mut self,
2582        model: &'m DiscreteModel,
2583        include_constants: bool,
2584    ) -> Result<FunctionValue<'ctx>> {
2585        self.clear();
2586        let void_type = self.context.void_type();
2587        let fn_type = void_type.fn_type(
2588            &[
2589                self.real_type.into(),
2590                self.real_ptr_type.into(),
2591                self.real_ptr_type.into(),
2592                self.real_ptr_type.into(),
2593                self.int_type.into(),
2594                self.int_type.into(),
2595            ],
2596            false,
2597        );
2598        let fn_arg_names = &["t", "u", "data", "out", "thread_id", "thread_dim"];
2599        let function_name = if include_constants {
2600            "calc_out_full"
2601        } else {
2602            "calc_out"
2603        };
2604        let function = self.module.add_function(function_name, fn_type, None);
2605
2606        // add noalias
2607        let alias_id = Attribute::get_named_enum_kind_id("noalias");
2608        let noalign = self.context.create_enum_attribute(alias_id, 0);
2609        for i in &[1, 2] {
2610            function.add_attribute(AttributeLoc::Param(*i), noalign);
2611        }
2612
2613        let basic_block = self.context.append_basic_block(function, "entry");
2614        self.fn_value_opt = Some(function);
2615        self.builder.position_at_end(basic_block);
2616
2617        for (i, arg) in function.get_param_iter().enumerate() {
2618            let name = fn_arg_names[i];
2619            let alloca = self.function_arg_alloca(name, arg);
2620            self.insert_param(name, alloca);
2621        }
2622
2623        self.insert_state(model.state());
2624        self.insert_data(model);
2625        self.insert_indices();
2626
2627        // print thread_id and thread_dim
2628        //let thread_id = function.get_nth_param(3).unwrap();
2629        //let thread_dim = function.get_nth_param(4).unwrap();
2630        //self.compile_print_value("thread_id", PrintValue::Int(thread_id.into_int_value()))?;
2631        //self.compile_print_value("thread_dim", PrintValue::Int(thread_dim.into_int_value()))?;
2632        if let Some(out) = model.out() {
2633            let mut nbarriers = 0;
2634            let mut total_barriers =
2635                (model.time_dep_defns().len() + model.state_dep_defns().len() + 1) as u64;
2636            if include_constants {
2637                total_barriers += model.input_dep_defns().len() as u64;
2638                // calculate time independant definitions
2639                for tensor in model.input_dep_defns() {
2640                    self.jit_compile_tensor(tensor, Some(*self.get_var(tensor)))?;
2641                    self.jit_compile_call_barrier(nbarriers, total_barriers);
2642                    nbarriers += 1;
2643                }
2644            }
2645
2646            // calculate time dependant definitions
2647            for tensor in model.time_dep_defns() {
2648                self.jit_compile_tensor(tensor, Some(*self.get_var(tensor)))?;
2649                self.jit_compile_call_barrier(nbarriers, total_barriers);
2650                nbarriers += 1;
2651            }
2652
2653            // calculate state dependant definitions
2654            #[allow(clippy::explicit_counter_loop)]
2655            for a in model.state_dep_defns() {
2656                self.jit_compile_tensor(a, Some(*self.get_var(a)))?;
2657                self.jit_compile_call_barrier(nbarriers, total_barriers);
2658                nbarriers += 1;
2659            }
2660
2661            self.jit_compile_tensor(out, Some(*self.get_var(model.out().unwrap())))?;
2662            self.jit_compile_call_barrier(nbarriers, total_barriers);
2663        }
2664        self.builder.build_return(None)?;
2665
2666        if function.verify(true) {
2667            Ok(function)
2668        } else {
2669            function.print_to_stderr();
2670            unsafe {
2671                function.delete();
2672            }
2673            Err(anyhow!("Invalid generated function."))
2674        }
2675    }
2676
    /// Build the `calc_stop(t, u, data, root, thread_id, thread_dim)`
    /// function, which evaluates the model's stop (root-finding) tensor into
    /// `root`. If the model has no stop condition the generated function is
    /// a no-op.
    pub fn compile_calc_stop<'m>(
        &mut self,
        model: &'m DiscreteModel,
    ) -> Result<FunctionValue<'ctx>> {
        self.clear();
        let void_type = self.context.void_type();
        let fn_type = void_type.fn_type(
            &[
                self.real_type.into(),
                self.real_ptr_type.into(),
                self.real_ptr_type.into(),
                self.real_ptr_type.into(),
                self.int_type.into(),
                self.int_type.into(),
            ],
            false,
        );
        let fn_arg_names = &["t", "u", "data", "root", "thread_id", "thread_dim"];
        let function = self.module.add_function("calc_stop", fn_type, None);

        // add noalias on the three pointer arguments
        let alias_id = Attribute::get_named_enum_kind_id("noalias");
        let noalign = self.context.create_enum_attribute(alias_id, 0);
        for i in &[1, 2, 3] {
            function.add_attribute(AttributeLoc::Param(*i), noalign);
        }

        let basic_block = self.context.append_basic_block(function, "entry");
        self.fn_value_opt = Some(function);
        self.builder.position_at_end(basic_block);

        // spill scalar args to allocas and register every param by name
        for (i, arg) in function.get_param_iter().enumerate() {
            let name = fn_arg_names[i];
            let alloca = self.function_arg_alloca(name, arg);
            self.insert_param(name, alloca);
        }

        self.insert_state(model.state());
        self.insert_data(model);
        self.insert_indices();

        if let Some(stop) = model.stop() {
            // calculate time dependant definitions, with a thread barrier
            // after each compiled tensor
            let mut nbarriers = 0;
            let total_barriers =
                (model.time_dep_defns().len() + model.state_dep_defns().len() + 1) as u64;
            for tensor in model.time_dep_defns() {
                self.jit_compile_tensor(tensor, Some(*self.get_var(tensor)))?;
                self.jit_compile_call_barrier(nbarriers, total_barriers);
                nbarriers += 1;
            }

            // calculate state dependant definitions
            for a in model.state_dep_defns() {
                self.jit_compile_tensor(a, Some(*self.get_var(a)))?;
                self.jit_compile_call_barrier(nbarriers, total_barriers);
                nbarriers += 1;
            }

            // write the stop tensor into the output pointer
            let res_ptr = self.get_param("root");
            self.jit_compile_tensor(stop, Some(*res_ptr))?;
            self.jit_compile_call_barrier(nbarriers, total_barriers);
        }
        self.builder.build_return(None)?;

        if function.verify(true) {
            Ok(function)
        } else {
            function.print_to_stderr();
            unsafe {
                function.delete();
            }
            Err(anyhow!("Invalid generated function."))
        }
    }
2752
2753    pub fn compile_rhs<'m>(
2754        &mut self,
2755        model: &'m DiscreteModel,
2756        include_constants: bool,
2757    ) -> Result<FunctionValue<'ctx>> {
2758        self.clear();
2759        let void_type = self.context.void_type();
2760        let fn_type = void_type.fn_type(
2761            &[
2762                self.real_type.into(),
2763                self.real_ptr_type.into(),
2764                self.real_ptr_type.into(),
2765                self.real_ptr_type.into(),
2766                self.int_type.into(),
2767                self.int_type.into(),
2768            ],
2769            false,
2770        );
2771        let fn_arg_names = &["t", "u", "data", "rr", "thread_id", "thread_dim"];
2772        let function_name = if include_constants { "rhs_full" } else { "rhs" };
2773        let function = self.module.add_function(function_name, fn_type, None);
2774
2775        // add noalias
2776        let alias_id = Attribute::get_named_enum_kind_id("noalias");
2777        let noalign = self.context.create_enum_attribute(alias_id, 0);
2778        for i in &[1, 2, 3] {
2779            function.add_attribute(AttributeLoc::Param(*i), noalign);
2780        }
2781
2782        let basic_block = self.context.append_basic_block(function, "entry");
2783        self.fn_value_opt = Some(function);
2784        self.builder.position_at_end(basic_block);
2785
2786        for (i, arg) in function.get_param_iter().enumerate() {
2787            let name = fn_arg_names[i];
2788            let alloca = self.function_arg_alloca(name, arg);
2789            self.insert_param(name, alloca);
2790        }
2791
2792        self.insert_state(model.state());
2793        self.insert_data(model);
2794        self.insert_indices();
2795
2796        let mut nbarriers = 0;
2797        let mut total_barriers =
2798            (model.time_dep_defns().len() + model.state_dep_defns().len() + 1) as u64;
2799        if include_constants {
2800            total_barriers += model.input_dep_defns().len() as u64;
2801            // calculate constant definitions
2802            for tensor in model.input_dep_defns() {
2803                self.jit_compile_tensor(tensor, Some(*self.get_var(tensor)))?;
2804                self.jit_compile_call_barrier(nbarriers, total_barriers);
2805                nbarriers += 1;
2806            }
2807        }
2808
2809        // calculate time dependant definitions
2810        for tensor in model.time_dep_defns() {
2811            self.jit_compile_tensor(tensor, Some(*self.get_var(tensor)))?;
2812            self.jit_compile_call_barrier(nbarriers, total_barriers);
2813            nbarriers += 1;
2814        }
2815
2816        // TODO: could split state dep defns into before and after F
2817        for a in model.state_dep_defns() {
2818            self.jit_compile_tensor(a, Some(*self.get_var(a)))?;
2819            self.jit_compile_call_barrier(nbarriers, total_barriers);
2820            nbarriers += 1;
2821        }
2822
2823        // F
2824        let res_ptr = self.get_param("rr");
2825        self.jit_compile_tensor(model.rhs(), Some(*res_ptr))?;
2826        self.jit_compile_call_barrier(nbarriers, total_barriers);
2827
2828        self.builder.build_return(None)?;
2829
2830        if function.verify(true) {
2831            Ok(function)
2832        } else {
2833            function.print_to_stderr();
2834            unsafe {
2835                function.delete();
2836            }
2837            Err(anyhow!("Invalid generated function."))
2838        }
2839    }
2840
2841    pub fn compile_mass<'m>(&mut self, model: &'m DiscreteModel) -> Result<FunctionValue<'ctx>> {
2842        self.clear();
2843        let void_type = self.context.void_type();
2844        let fn_type = void_type.fn_type(
2845            &[
2846                self.real_type.into(),
2847                self.real_ptr_type.into(),
2848                self.real_ptr_type.into(),
2849                self.real_ptr_type.into(),
2850                self.int_type.into(),
2851                self.int_type.into(),
2852            ],
2853            false,
2854        );
2855        let fn_arg_names = &["t", "dudt", "data", "rr", "thread_id", "thread_dim"];
2856        let function = self.module.add_function("mass", fn_type, None);
2857
2858        // add noalias
2859        let alias_id = Attribute::get_named_enum_kind_id("noalias");
2860        let noalign = self.context.create_enum_attribute(alias_id, 0);
2861        for i in &[1, 2, 3] {
2862            function.add_attribute(AttributeLoc::Param(*i), noalign);
2863        }
2864
2865        let basic_block = self.context.append_basic_block(function, "entry");
2866        self.fn_value_opt = Some(function);
2867        self.builder.position_at_end(basic_block);
2868
2869        for (i, arg) in function.get_param_iter().enumerate() {
2870            let name = fn_arg_names[i];
2871            let alloca = self.function_arg_alloca(name, arg);
2872            self.insert_param(name, alloca);
2873        }
2874
2875        // only put code in this function if we have a state_dot and lhs
2876        if model.state_dot().is_some() && model.lhs().is_some() {
2877            let state_dot = model.state_dot().unwrap();
2878            let lhs = model.lhs().unwrap();
2879
2880            self.insert_dot_state(state_dot);
2881            self.insert_data(model);
2882            self.insert_indices();
2883
2884            // calculate time dependant definitions
2885            let mut nbarriers = 0;
2886            let total_barriers =
2887                (model.time_dep_defns().len() + model.dstate_dep_defns().len() + 1) as u64;
2888            for tensor in model.time_dep_defns() {
2889                self.jit_compile_tensor(tensor, Some(*self.get_var(tensor)))?;
2890                self.jit_compile_call_barrier(nbarriers, total_barriers);
2891                nbarriers += 1;
2892            }
2893
2894            for a in model.dstate_dep_defns() {
2895                self.jit_compile_tensor(a, Some(*self.get_var(a)))?;
2896                self.jit_compile_call_barrier(nbarriers, total_barriers);
2897                nbarriers += 1;
2898            }
2899
2900            // mass
2901            let res_ptr = self.get_param("rr");
2902            self.jit_compile_tensor(lhs, Some(*res_ptr))?;
2903            self.jit_compile_call_barrier(nbarriers, total_barriers);
2904        }
2905
2906        self.builder.build_return(None)?;
2907
2908        if function.verify(true) {
2909            Ok(function)
2910        } else {
2911            function.print_to_stderr();
2912            unsafe {
2913                function.delete();
2914            }
2915            Err(anyhow!("Invalid generated function."))
2916        }
2917    }
2918
2919    pub fn compile_gradient(
2920        &mut self,
2921        original_function: FunctionValue<'ctx>,
2922        args_type: &[CompileGradientArgType],
2923        mode: CompileMode,
2924        fn_name: &str,
2925    ) -> Result<FunctionValue<'ctx>> {
2926        self.clear();
2927
2928        // construct the gradient function
2929        let mut fn_type: Vec<BasicMetadataTypeEnum> = Vec::new();
2930
2931        let orig_fn_type_ptr = Self::fn_pointer_type(self.context, original_function.get_type());
2932
2933        let mut enzyme_fn_type: Vec<BasicMetadataTypeEnum> = vec![orig_fn_type_ptr.into()];
2934        let mut start_param_index: Vec<u32> = Vec::new();
2935        let mut ptr_arg_indices: Vec<u32> = Vec::new();
2936        for (i, arg) in original_function.get_param_iter().enumerate() {
2937            start_param_index.push(u32::try_from(fn_type.len()).unwrap());
2938            let arg_type = arg.get_type();
2939            fn_type.push(arg_type.into());
2940
2941            // constant args with type T in the original funciton have 2 args of type [int, T]
2942            enzyme_fn_type.push(self.int_type.into());
2943            enzyme_fn_type.push(arg.get_type().into());
2944
2945            if arg_type.is_pointer_type() {
2946                ptr_arg_indices.push(u32::try_from(i).unwrap());
2947            }
2948
2949            match args_type[i] {
2950                CompileGradientArgType::Dup | CompileGradientArgType::DupNoNeed => {
2951                    fn_type.push(arg.get_type().into());
2952                    enzyme_fn_type.push(arg.get_type().into());
2953                }
2954                CompileGradientArgType::Const => {}
2955            }
2956        }
2957        let void_type = self.context.void_type();
2958        let fn_type = void_type.fn_type(fn_type.as_slice(), false);
2959        let function = self.module.add_function(fn_name, fn_type, None);
2960
2961        // add noalias
2962        let alias_id = Attribute::get_named_enum_kind_id("noalias");
2963        let noalign = self.context.create_enum_attribute(alias_id, 0);
2964        for i in ptr_arg_indices {
2965            function.add_attribute(AttributeLoc::Param(i), noalign);
2966        }
2967
2968        let basic_block = self.context.append_basic_block(function, "entry");
2969        self.fn_value_opt = Some(function);
2970        self.builder.position_at_end(basic_block);
2971
2972        let mut enzyme_fn_args: Vec<BasicMetadataValueEnum> = Vec::new();
2973        let mut input_activity = Vec::new();
2974        let mut arg_trees = Vec::new();
2975        for (i, arg) in original_function.get_param_iter().enumerate() {
2976            let param_index = start_param_index[i];
2977            let fn_arg = function.get_nth_param(param_index).unwrap();
2978
2979            // we'll probably only get double or pointers to doubles, so let assume this for now
2980            // todo: perhaps refactor this into a recursive function, might be overkill
2981            let concrete_type = match arg.get_type() {
2982                BasicTypeEnum::PointerType(_) => CConcreteType_DT_Pointer,
2983                BasicTypeEnum::FloatType(t) => {
2984                    if t == self.context.f64_type() {
2985                        CConcreteType_DT_Double
2986                    } else {
2987                        panic!("unsupported type")
2988                    }
2989                }
2990                BasicTypeEnum::IntType(_) => CConcreteType_DT_Integer,
2991                _ => panic!("unsupported type"),
2992            };
2993            let new_tree = unsafe {
2994                EnzymeNewTypeTreeCT(
2995                    concrete_type,
2996                    self.context.as_ctx_ref() as *mut LLVMOpaqueContext,
2997                )
2998            };
2999            unsafe { EnzymeTypeTreeOnlyEq(new_tree, -1) };
3000
3001            // pointer to double
3002            if concrete_type == CConcreteType_DT_Pointer {
3003                // assume the pointer is to a double
3004                let inner_concrete_type = CConcreteType_DT_Double;
3005                let inner_new_tree = unsafe {
3006                    EnzymeNewTypeTreeCT(
3007                        inner_concrete_type,
3008                        self.context.as_ctx_ref() as *mut LLVMOpaqueContext,
3009                    )
3010                };
3011                unsafe { EnzymeTypeTreeOnlyEq(inner_new_tree, -1) };
3012                unsafe { EnzymeTypeTreeOnlyEq(inner_new_tree, -1) };
3013                unsafe { EnzymeMergeTypeTree(new_tree, inner_new_tree) };
3014            }
3015
3016            arg_trees.push(new_tree);
3017            match args_type[i] {
3018                CompileGradientArgType::Dup => {
3019                    // pass in the arg value
3020                    enzyme_fn_args.push(fn_arg.into());
3021
3022                    // pass in the darg value
3023                    let fn_darg = function.get_nth_param(param_index + 1).unwrap();
3024                    enzyme_fn_args.push(fn_darg.into());
3025
3026                    input_activity.push(CDIFFE_TYPE_DFT_DUP_ARG);
3027                }
3028                CompileGradientArgType::DupNoNeed => {
3029                    // pass in the arg value
3030                    enzyme_fn_args.push(fn_arg.into());
3031
3032                    // pass in the darg value
3033                    let fn_darg = function.get_nth_param(param_index + 1).unwrap();
3034                    enzyme_fn_args.push(fn_darg.into());
3035
3036                    input_activity.push(CDIFFE_TYPE_DFT_DUP_NONEED);
3037                }
3038                CompileGradientArgType::Const => {
3039                    // pass in the arg value
3040                    enzyme_fn_args.push(fn_arg.into());
3041
3042                    input_activity.push(CDIFFE_TYPE_DFT_CONSTANT);
3043                }
3044            }
3045        }
3046        // if we have void ret, this must be false;
3047        let ret_primary_ret = false;
3048        let diff_ret = false;
3049        let ret_activity = CDIFFE_TYPE_DFT_CONSTANT;
3050        let ret_tree = unsafe {
3051            EnzymeNewTypeTreeCT(
3052                CConcreteType_DT_Anything,
3053                self.context.as_ctx_ref() as *mut LLVMOpaqueContext,
3054            )
3055        };
3056
3057        // always optimize
3058        let fnc_opt_base = true;
3059        let logic_ref: EnzymeLogicRef = unsafe { CreateEnzymeLogic(fnc_opt_base as u8) };
3060
3061        let kv_tmp = IntList {
3062            data: std::ptr::null_mut(),
3063            size: 0,
3064        };
3065        let mut known_values = vec![kv_tmp; input_activity.len()];
3066
3067        let fn_type_info = CFnTypeInfo {
3068            Arguments: arg_trees.as_mut_ptr(),
3069            Return: ret_tree,
3070            KnownValues: known_values.as_mut_ptr(),
3071        };
3072
3073        let type_analysis: EnzymeTypeAnalysisRef =
3074            unsafe { CreateTypeAnalysis(logic_ref, std::ptr::null_mut(), std::ptr::null_mut(), 0) };
3075
3076        let mut args_uncacheable = vec![0; arg_trees.len()];
3077
3078        let enzyme_function = match mode {
3079            CompileMode::Forward | CompileMode::ForwardSens => unsafe {
3080                EnzymeCreateForwardDiff(
3081                    logic_ref, // Logic
3082                    std::ptr::null_mut(),
3083                    std::ptr::null_mut(),
3084                    original_function.as_value_ref(),
3085                    ret_activity, // LLVM function, return type
3086                    input_activity.as_mut_ptr(),
3087                    input_activity.len(), // constant arguments
3088                    type_analysis,        // type analysis struct
3089                    ret_primary_ret as u8,
3090                    CDerivativeMode_DEM_ForwardMode, // return value, dret_used, top_level which was 1
3091                    1,                               // free memory
3092                    0,                               // runtime activity
3093                    1,                               // vector mode width
3094                    std::ptr::null_mut(),
3095                    fn_type_info, // additional_arg, type info (return + args)
3096                    args_uncacheable.as_mut_ptr(),
3097                    args_uncacheable.len(), // uncacheable arguments
3098                    std::ptr::null_mut(),   // write augmented function to this
3099                )
3100            },
3101            CompileMode::Reverse | CompileMode::ReverseSens => {
3102                let mut call_enzyme = || unsafe {
3103                    EnzymeCreatePrimalAndGradient(
3104                        logic_ref,
3105                        std::ptr::null_mut(),
3106                        std::ptr::null_mut(),
3107                        original_function.as_value_ref(),
3108                        ret_activity,
3109                        input_activity.as_mut_ptr(),
3110                        input_activity.len(),
3111                        type_analysis,
3112                        ret_primary_ret as u8,
3113                        diff_ret as u8,
3114                        CDerivativeMode_DEM_ReverseModeCombined,
3115                        0,
3116                        1,
3117                        1,
3118                        std::ptr::null_mut(),
3119                        0,
3120                        fn_type_info,
3121                        args_uncacheable.as_mut_ptr(),
3122                        args_uncacheable.len(),
3123                        std::ptr::null_mut(),
3124                        if self.threaded { 1 } else { 0 },
3125                    )
3126                };
3127                if self.threaded {
3128                    // the register call handler alters a global variable, so we need to lock it
3129                    let _lock = my_mutex.lock().unwrap();
3130                    let barrier_string = CString::new("barrier").unwrap();
3131                    unsafe {
3132                        EnzymeRegisterCallHandler(
3133                            barrier_string.as_ptr(),
3134                            Some(fwd_handler),
3135                            Some(rev_handler),
3136                        )
3137                    };
3138                    let ret = call_enzyme();
3139                    // unregister it so some other thread doesn't use it
3140                    unsafe { EnzymeRegisterCallHandler(barrier_string.as_ptr(), None, None) };
3141                    ret
3142                } else {
3143                    call_enzyme()
3144                }
3145            }
3146        };
3147
3148        // free everything
3149        unsafe { FreeEnzymeLogic(logic_ref) };
3150        unsafe { FreeTypeAnalysis(type_analysis) };
3151        unsafe { EnzymeFreeTypeTree(ret_tree) };
3152        for tree in arg_trees {
3153            unsafe { EnzymeFreeTypeTree(tree) };
3154        }
3155
3156        // call enzyme function
3157        let enzyme_function =
3158            unsafe { FunctionValue::new(enzyme_function as LLVMValueRef) }.unwrap();
3159        self.builder
3160            .build_call(enzyme_function, enzyme_fn_args.as_slice(), "enzyme_call")?;
3161
3162        // return
3163        self.builder.build_return(None)?;
3164
3165        if function.verify(true) {
3166            Ok(function)
3167        } else {
3168            function.print_to_stderr();
3169            enzyme_function.print_to_stderr();
3170            unsafe {
3171                function.delete();
3172            }
3173            Err(anyhow!("Invalid generated function."))
3174        }
3175    }
3176
3177    pub fn compile_get_dims(&mut self, model: &DiscreteModel) -> Result<FunctionValue<'ctx>> {
3178        self.clear();
3179        let fn_type = self.context.void_type().fn_type(
3180            &[
3181                self.int_ptr_type.into(),
3182                self.int_ptr_type.into(),
3183                self.int_ptr_type.into(),
3184                self.int_ptr_type.into(),
3185                self.int_ptr_type.into(),
3186                self.int_ptr_type.into(),
3187            ],
3188            false,
3189        );
3190
3191        let function = self.module.add_function("get_dims", fn_type, None);
3192        let block = self.context.append_basic_block(function, "entry");
3193        let fn_arg_names = &["states", "inputs", "outputs", "data", "stop", "has_mass"];
3194        self.builder.position_at_end(block);
3195
3196        for (i, arg) in function.get_param_iter().enumerate() {
3197            let name = fn_arg_names[i];
3198            let alloca = self.function_arg_alloca(name, arg);
3199            self.insert_param(name, alloca);
3200        }
3201
3202        self.insert_indices();
3203
3204        let number_of_states = model.state().nnz() as u64;
3205        let number_of_inputs = model.inputs().iter().fold(0, |acc, x| acc + x.nnz()) as u64;
3206        let number_of_outputs = match model.out() {
3207            Some(out) => out.nnz() as u64,
3208            None => 0,
3209        };
3210        let number_of_stop = if let Some(stop) = model.stop() {
3211            stop.nnz() as u64
3212        } else {
3213            0
3214        };
3215        let has_mass = match model.lhs().is_some() {
3216            true => 1u64,
3217            false => 0u64,
3218        };
3219        let data_len = self.layout.data().len() as u64;
3220        self.builder.build_store(
3221            *self.get_param("states"),
3222            self.int_type.const_int(number_of_states, false),
3223        )?;
3224        self.builder.build_store(
3225            *self.get_param("inputs"),
3226            self.int_type.const_int(number_of_inputs, false),
3227        )?;
3228        self.builder.build_store(
3229            *self.get_param("outputs"),
3230            self.int_type.const_int(number_of_outputs, false),
3231        )?;
3232        self.builder.build_store(
3233            *self.get_param("data"),
3234            self.int_type.const_int(data_len, false),
3235        )?;
3236        self.builder.build_store(
3237            *self.get_param("stop"),
3238            self.int_type.const_int(number_of_stop, false),
3239        )?;
3240        self.builder.build_store(
3241            *self.get_param("has_mass"),
3242            self.int_type.const_int(has_mass, false),
3243        )?;
3244        self.builder.build_return(None)?;
3245
3246        if function.verify(true) {
3247            Ok(function)
3248        } else {
3249            function.print_to_stderr();
3250            unsafe {
3251                function.delete();
3252            }
3253            Err(anyhow!("Invalid generated function."))
3254        }
3255    }
3256
3257    pub fn compile_get_tensor(
3258        &mut self,
3259        model: &DiscreteModel,
3260        name: &str,
3261    ) -> Result<FunctionValue<'ctx>> {
3262        self.clear();
3263        let real_ptr_ptr_type = Self::pointer_type(self.context, self.real_ptr_type.into());
3264        let fn_type = self.context.void_type().fn_type(
3265            &[
3266                self.real_ptr_type.into(),
3267                real_ptr_ptr_type.into(),
3268                self.int_ptr_type.into(),
3269            ],
3270            false,
3271        );
3272        let function_name = format!("get_tensor_{name}");
3273        let function = self
3274            .module
3275            .add_function(function_name.as_str(), fn_type, None);
3276        let basic_block = self.context.append_basic_block(function, "entry");
3277        self.fn_value_opt = Some(function);
3278
3279        let fn_arg_names = &["data", "tensor_data", "tensor_size"];
3280        self.builder.position_at_end(basic_block);
3281
3282        for (i, arg) in function.get_param_iter().enumerate() {
3283            let name = fn_arg_names[i];
3284            let alloca = self.function_arg_alloca(name, arg);
3285            self.insert_param(name, alloca);
3286        }
3287
3288        self.insert_data(model);
3289        let ptr = self.get_param(name);
3290        let tensor_size = self.layout.get_layout(name).unwrap().nnz() as u64;
3291        let tensor_size_value = self.int_type.const_int(tensor_size, false);
3292        self.builder
3293            .build_store(*self.get_param("tensor_data"), ptr.as_basic_value_enum())?;
3294        self.builder
3295            .build_store(*self.get_param("tensor_size"), tensor_size_value)?;
3296        self.builder.build_return(None)?;
3297
3298        if function.verify(true) {
3299            Ok(function)
3300        } else {
3301            function.print_to_stderr();
3302            unsafe {
3303                function.delete();
3304            }
3305            Err(anyhow!("Invalid generated function."))
3306        }
3307    }
3308
3309    pub fn compile_get_constant(
3310        &mut self,
3311        model: &DiscreteModel,
3312        name: &str,
3313    ) -> Result<FunctionValue<'ctx>> {
3314        self.clear();
3315        let real_ptr_ptr_type = Self::pointer_type(self.context, self.real_ptr_type.into());
3316        let fn_type = self
3317            .context
3318            .void_type()
3319            .fn_type(&[real_ptr_ptr_type.into(), self.int_ptr_type.into()], false);
3320        let function_name = format!("get_constant_{name}");
3321        let function = self
3322            .module
3323            .add_function(function_name.as_str(), fn_type, None);
3324        let basic_block = self.context.append_basic_block(function, "entry");
3325        self.fn_value_opt = Some(function);
3326
3327        let fn_arg_names = &["tensor_data", "tensor_size"];
3328        self.builder.position_at_end(basic_block);
3329
3330        for (i, arg) in function.get_param_iter().enumerate() {
3331            let name = fn_arg_names[i];
3332            let alloca = self.function_arg_alloca(name, arg);
3333            self.insert_param(name, alloca);
3334        }
3335
3336        self.insert_constants(model);
3337        let ptr = self.get_param(name);
3338        let tensor_size = self.layout.get_layout(name).unwrap().nnz() as u64;
3339        let tensor_size_value = self.int_type.const_int(tensor_size, false);
3340        self.builder
3341            .build_store(*self.get_param("tensor_data"), ptr.as_basic_value_enum())?;
3342        self.builder
3343            .build_store(*self.get_param("tensor_size"), tensor_size_value)?;
3344        self.builder.build_return(None)?;
3345
3346        if function.verify(true) {
3347            Ok(function)
3348        } else {
3349            function.print_to_stderr();
3350            unsafe {
3351                function.delete();
3352            }
3353            Err(anyhow!("Invalid generated function."))
3354        }
3355    }
3356
    /// Builds `get_inputs` (when `is_get` is true) or `set_inputs`:
    /// `void(real* inputs, real* data)`.
    ///
    /// For each model input tensor, emits a copy loop over its `nnz` elements:
    /// `get_inputs` copies tensor storage -> `inputs`, `set_inputs` copies
    /// `inputs` -> tensor storage. Consecutive inputs occupy consecutive slices
    /// of the flat `inputs` array, tracked by `inputs_index`.
    ///
    /// NOTE(review): each per-input loop is emitted do-while style (body first,
    /// exit test after), so it appears to assume every input has nnz >= 1 —
    /// TODO confirm an nnz == 0 input cannot reach here.
    pub fn compile_inputs(
        &mut self,
        model: &DiscreteModel,
        is_get: bool,
    ) -> Result<FunctionValue<'ctx>> {
        self.clear();
        let void_type = self.context.void_type();
        // Signature: void(real* inputs, real* data)
        let fn_type = void_type.fn_type(
            &[self.real_ptr_type.into(), self.real_ptr_type.into()],
            false,
        );
        let function_name = if is_get { "get_inputs" } else { "set_inputs" };
        let function = self.module.add_function(function_name, fn_type, None);
        // `block` tracks the predecessor of the next loop header; it starts at
        // entry and is advanced to each loop's post-block as we go.
        let mut block = self.context.append_basic_block(function, "entry");
        self.fn_value_opt = Some(function);

        let fn_arg_names = &["inputs", "data"];
        self.builder.position_at_end(block);

        // Register each function argument under its name via a local alloca.
        for (i, arg) in function.get_param_iter().enumerate() {
            let name = fn_arg_names[i];
            let alloca = self.function_arg_alloca(name, arg);
            self.insert_param(name, alloca);
        }

        // Running offset of the current input within the flat `inputs` array.
        let mut inputs_index = 0usize;
        for input in model.inputs() {
            let name = format!("input_{}", input.name());
            self.insert_tensor(input, false);
            let ptr = self.get_var(input);
            // loop thru the elements of this input and set/get them using the inputs ptr
            let inputs_start_index = self.int_type.const_int(inputs_index as u64, false);
            let start_index = self.int_type.const_int(0, false);
            let end_index = self
                .int_type
                .const_int(input.nnz().try_into().unwrap(), false);

            // Loop header/body block; the phi `i` starts at 0 when entered from
            // the preceding block and is updated with `next_index` on back-edges.
            let input_block = self.context.append_basic_block(function, name.as_str());
            self.builder.build_unconditional_branch(input_block)?;
            self.builder.position_at_end(input_block);
            let index = self.builder.build_phi(self.int_type, "i")?;
            index.add_incoming(&[(&start_index, block)]);

            // loop body - copy value from inputs to data
            let curr_input_index = index.as_basic_value().into_int_value();
            // Element i of this input's tensor storage.
            let input_ptr = Self::get_ptr_to_index(
                &self.builder,
                self.real_type,
                ptr,
                curr_input_index,
                name.as_str(),
            );
            // Element (offset + i) of the flat `inputs` array.
            let curr_inputs_index =
                self.builder
                    .build_int_add(inputs_start_index, curr_input_index, name.as_str())?;
            let inputs_ptr = Self::get_ptr_to_index(
                &self.builder,
                self.real_type,
                self.get_param("inputs"),
                curr_inputs_index,
                name.as_str(),
            );
            // Copy direction depends on get vs set.
            if is_get {
                let input_value = self
                    .build_load(self.real_type, input_ptr, name.as_str())?
                    .into_float_value();
                self.builder.build_store(inputs_ptr, input_value)?;
            } else {
                let input_value = self
                    .build_load(self.real_type, inputs_ptr, name.as_str())?
                    .into_float_value();
                self.builder.build_store(input_ptr, input_value)?;
            }

            // increment loop index
            let one = self.int_type.const_int(1, false);
            let next_index = self
                .builder
                .build_int_add(curr_input_index, one, name.as_str())?;
            index.add_incoming(&[(&next_index, input_block)]);

            // loop condition
            let loop_while = self.builder.build_int_compare(
                IntPredicate::ULT,
                next_index,
                end_index,
                name.as_str(),
            )?;
            let post_block = self.context.append_basic_block(function, name.as_str());
            self.builder
                .build_conditional_branch(loop_while, input_block, post_block)?;
            self.builder.position_at_end(post_block);

            // get ready for next input
            block = post_block;
            inputs_index += input.nnz();
        }
        self.builder.build_return(None)?;

        if function.verify(true) {
            Ok(function)
        } else {
            // Dump the broken IR for debugging, then remove it from the module.
            function.print_to_stderr();
            unsafe {
                function.delete();
            }
            Err(anyhow!("Invalid generated function."))
        }
    }
3466
    /// Builds the `set_id` function: `void(real* id)`.
    ///
    /// Fills `id` with one value per state element: 0.0 for algebraic elements
    /// and 1.0 for differential ones (note the inversion — `is_algebraic` true
    /// maps to 0.0). One copy loop is emitted per state block, chained together
    /// via `block`, with `id_index` tracking the running offset into `id`.
    ///
    /// NOTE(review): each per-block loop is emitted do-while style (body first,
    /// exit test after), so it appears to assume every state block has
    /// nnz >= 1 — TODO confirm an nnz == 0 block cannot reach here.
    pub fn compile_set_id(&mut self, model: &DiscreteModel) -> Result<FunctionValue<'ctx>> {
        self.clear();
        let void_type = self.context.void_type();
        // Signature: void(real* id)
        let fn_type = void_type.fn_type(&[self.real_ptr_type.into()], false);
        let function = self.module.add_function("set_id", fn_type, None);
        // `block` tracks the predecessor of the next loop header; it starts at
        // entry and is advanced to each loop's post-block as we go.
        let mut block = self.context.append_basic_block(function, "entry");

        let fn_arg_names = &["id"];
        self.builder.position_at_end(block);

        // Register the single `id` argument via a local alloca.
        for (i, arg) in function.get_param_iter().enumerate() {
            let name = fn_arg_names[i];
            let alloca = self.function_arg_alloca(name, arg);
            self.insert_param(name, alloca);
        }

        // Running offset of the current state block within `id`.
        let mut id_index = 0usize;
        for (blk, is_algebraic) in zip(model.state().elmts(), model.is_algebraic()) {
            let name = blk.name().unwrap_or("unknown");
            // loop thru the elements of this state blk and set the corresponding elements of id
            let id_start_index = self.int_type.const_int(id_index as u64, false);
            let blk_start_index = self.int_type.const_int(0, false);
            let blk_end_index = self
                .int_type
                .const_int(blk.nnz().try_into().unwrap(), false);

            // Loop header/body block; the phi `i` starts at 0 when entered from
            // the preceding block and is updated with `next_index` on back-edges.
            let blk_block = self.context.append_basic_block(function, name);
            self.builder.build_unconditional_branch(blk_block)?;
            self.builder.position_at_end(blk_block);
            let index = self.builder.build_phi(self.int_type, "i")?;
            index.add_incoming(&[(&blk_start_index, block)]);

            // loop body - copy value from inputs to data
            let curr_blk_index = index.as_basic_value().into_int_value();
            // Element (offset + i) of the `id` array.
            let curr_id_index = self
                .builder
                .build_int_add(id_start_index, curr_blk_index, name)?;
            let id_ptr = Self::get_ptr_to_index(
                &self.builder,
                self.real_type,
                self.get_param("id"),
                curr_id_index,
                name,
            );
            // Algebraic -> 0.0, differential -> 1.0 (whole block shares one flag).
            let is_algebraic_float = if *is_algebraic {
                0.0 as RealType
            } else {
                1.0 as RealType
            };
            let is_algebraic_value = self.real_type.const_float(is_algebraic_float);
            self.builder.build_store(id_ptr, is_algebraic_value)?;

            // increment loop index
            let one = self.int_type.const_int(1, false);
            let next_index = self.builder.build_int_add(curr_blk_index, one, name)?;
            index.add_incoming(&[(&next_index, blk_block)]);

            // loop condition
            let loop_while = self.builder.build_int_compare(
                IntPredicate::ULT,
                next_index,
                blk_end_index,
                name,
            )?;
            let post_block = self.context.append_basic_block(function, name);
            self.builder
                .build_conditional_branch(loop_while, blk_block, post_block)?;
            self.builder.position_at_end(post_block);

            // get ready for next blk
            block = post_block;
            id_index += blk.nnz();
        }
        self.builder.build_return(None)?;

        if function.verify(true) {
            Ok(function)
        } else {
            // Dump the broken IR for debugging, then remove it from the module.
            function.print_to_stderr();
            unsafe {
                function.delete();
            }
            Err(anyhow!("Invalid generated function."))
        }
    }
3552
    /// Returns a shared reference to the LLVM module that all the
    /// `compile_*` methods add their generated functions to.
    pub fn module(&self) -> &Module<'ctx> {
        &self.module
    }
3556}