diffsl/execution/llvm/
codegen.rs

use aliasable::boxed::AliasableBox;
use anyhow::{anyhow, Result};
use inkwell::attributes::{Attribute, AttributeLoc};
use inkwell::basic_block::BasicBlock;
use inkwell::builder::Builder;
use inkwell::context::{AsContextRef, Context};
use inkwell::execution_engine::ExecutionEngine;
use inkwell::intrinsics::Intrinsic;
use inkwell::module::{Linkage, Module};
use inkwell::passes::PassBuilderOptions;
use inkwell::targets::{FileType, InitializationConfig, Target, TargetMachine, TargetTriple};
use inkwell::types::{
    BasicMetadataTypeEnum, BasicType, BasicTypeEnum, FloatType, FunctionType, IntType, PointerType,
};
use inkwell::values::{
    AsValueRef, BasicMetadataValueEnum, BasicValue, BasicValueEnum, CallSiteValue, FloatValue,
    FunctionValue, GlobalValue, IntValue, PointerValue,
};
use inkwell::{
    AddressSpace, AtomicOrdering, AtomicRMWBinOp, FloatPredicate, GlobalVisibility, IntPredicate,
    OptimizationLevel,
};
use llvm_sys::core::{
    LLVMBuildCall2, LLVMGetArgOperand, LLVMGetBasicBlockParent, LLVMGetGlobalParent,
    LLVMGetInstructionParent, LLVMGetNamedFunction, LLVMGlobalGetValueType,
};
use llvm_sys::prelude::{LLVMBuilderRef, LLVMValueRef};
use std::collections::HashMap;
use std::ffi::CString;
use std::iter::zip;
use std::pin::Pin;
use target_lexicon::Triple;

type RealType = f64;

use crate::ast::{Ast, AstKind};
use crate::discretise::{DiscreteModel, Tensor, TensorBlock};
use crate::enzyme::{
    CConcreteType_DT_Anything, CConcreteType_DT_Double, CConcreteType_DT_Integer,
    CConcreteType_DT_Pointer, CDerivativeMode_DEM_ForwardMode,
    CDerivativeMode_DEM_ReverseModeCombined, CFnTypeInfo, CreateEnzymeLogic, CreateTypeAnalysis,
    DiffeGradientUtils, EnzymeCreateForwardDiff, EnzymeCreatePrimalAndGradient, EnzymeFreeTypeTree,
    EnzymeGradientUtilsNewFromOriginal, EnzymeLogicRef, EnzymeMergeTypeTree, EnzymeNewTypeTreeCT,
    EnzymeRegisterCallHandler, EnzymeTypeAnalysisRef, EnzymeTypeTreeOnlyEq, FreeEnzymeLogic,
    FreeTypeAnalysis, GradientUtils, IntList, LLVMOpaqueContext, CDIFFE_TYPE_DFT_CONSTANT,
    CDIFFE_TYPE_DFT_DUP_ARG, CDIFFE_TYPE_DFT_DUP_NONEED,
};
use crate::execution::compiler::CompilerMode;
use crate::execution::module::{
    CodegenModule, CodegenModuleCompile, CodegenModuleEmit, CodegenModuleJit,
};
use crate::execution::{DataLayout, Translation, TranslationFrom, TranslationTo};
use lazy_static::lazy_static;
use std::sync::Mutex;

lazy_static! {
    static ref my_mutex: Mutex<i32> = Mutex::new(0i32);
}

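/// Self-referential pair of an LLVM `Context` and the `CodeGen` that borrows
/// it. The `'static` lifetime on `codegen` is a lie upheld by `LlvmModule::new`:
/// the context lives on the heap behind an `AliasableBox` and the struct is
/// pinned, so the borrow stays valid as long as the field order below keeps
/// `codegen` dropping before `context`.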
struct ImmovableLlvmModule {
    // actually has lifetime of `context`
    // declared first so it's dropped before `context`
    codegen: Option<CodeGen<'static>>,
    // safety: we must never move out of this box as long as codegen is alive
    context: AliasableBox<Context>,
    _pin: std::marker::PhantomPinned,
}

pub struct LlvmModule {
    inner: Pin<Box<ImmovableLlvmModule>>,
    machine: TargetMachine,
}

unsafe impl Sync for LlvmModule {}

impl LlvmModule {
    fn new(triple: Option<Triple>, model: &DiscreteModel, threaded: bool) -> Result<Self> {
        let initialization_config = &InitializationConfig::default();
        Target::initialize_all(initialization_config);
        let host_triple = Triple::host();
        let (triple_str, native) = match triple {
            Some(ref triple) => (triple.to_string(), false),
            None => (host_triple.to_string(), true),
        };
        let triple = TargetTriple::create(triple_str.as_str());
        let target = Target::from_triple(&triple).unwrap();
        let cpu = if native {
            TargetMachine::get_host_cpu_name().to_string()
        } else {
            "generic".to_string()
        };
        let features = if native {
            TargetMachine::get_host_cpu_features().to_string()
        } else {
            "".to_string()
        };
        let machine = target
            .create_target_machine(
                &triple,
                cpu.as_str(),
                features.as_str(),
                inkwell::OptimizationLevel::Aggressive,
                inkwell::targets::RelocMode::Default,
                inkwell::targets::CodeModel::Default,
            )
            .unwrap();

        let context = AliasableBox::from_unique(Box::new(Context::create()));
        let mut pinned = Self {
            inner: Box::pin(ImmovableLlvmModule {
                codegen: None,
                context,
                _pin: std::marker::PhantomPinned,
            }),
            machine,
        };

        let context_ref = pinned.inner.context.as_ref();
        let real_type_str = "f64";
        let codegen = CodeGen::new(
            model,
            context_ref,
            context_ref.f64_type(),
            context_ref.i32_type(),
            real_type_str,
            threaded,
        )?;
        let codegen = unsafe { std::mem::transmute::<CodeGen<'_>, CodeGen<'static>>(codegen) };
        unsafe { pinned.inner.as_mut().get_unchecked_mut().codegen = Some(codegen) };
        Ok(pinned)
    }

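    /// Optimise the module before handing it to Enzyme: roughly `-O2` but with
    /// loop vectorization, SLP vectorization, and loop unrolling disabled. The
    /// long explicit pass string below appears to be the expanded `default<O2>`
    /// pipeline with those passes stripped out.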
    fn pre_autodiff_optimisation(&mut self) -> Result<()> {
        //let pass_manager_builder = PassManagerBuilder::create();
        //pass_manager_builder.set_optimization_level(inkwell::OptimizationLevel::Default);
        //let pass_manager = PassManager::create(());
        //pass_manager_builder.populate_module_pass_manager(&pass_manager);
        //pass_manager.run_on(self.codegen().module());

        //self.codegen().module().print_to_stderr();
        // optimise at -O2 no unrolling before giving to enzyme
        let pass_options = PassBuilderOptions::create();
        //pass_options.set_verify_each(true);
        //pass_options.set_debug_logging(true);
        //pass_options.set_loop_interleaving(true);
        pass_options.set_loop_vectorization(false);
        pass_options.set_loop_slp_vectorization(false);
        pass_options.set_loop_unrolling(false);
        //pass_options.set_forget_all_scev_in_loop_unroll(true);
        //pass_options.set_licm_mssa_opt_cap(1);
        //pass_options.set_licm_mssa_no_acc_for_promotion_cap(10);
        //pass_options.set_call_graph_profile(true);
        //pass_options.set_merge_functions(true);

        //let passes = "default<O2>";
        let passes = "annotation2metadata,forceattrs,inferattrs,coro-early,function<eager-inv>(lower-expect,simplifycfg<bonus-inst-threshold=1;no-forward-switch-cond;no-switch-range-to-icmp;no-switch-to-lookup;keep-loops;no-hoist-common-insts;no-sink-common-insts>,early-cse<>),openmp-opt,ipsccp,called-value-propagation,globalopt,function(mem2reg),function<eager-inv>(instcombine,simplifycfg<bonus-inst-threshold=1;no-forward-switch-cond;switch-range-to-icmp;no-switch-to-lookup;keep-loops;no-hoist-common-insts;no-sink-common-insts>),require<globals-aa>,function(invalidate<aa>),require<profile-summary>,cgscc(devirt<4>(inline<only-mandatory>,inline,function-attrs,openmp-opt-cgscc,function<eager-inv>(early-cse<memssa>,speculative-execution,jump-threading,correlated-propagation,simplifycfg<bonus-inst-threshold=1;no-forward-switch-cond;switch-range-to-icmp;no-switch-to-lookup;keep-loops;no-hoist-common-insts;no-sink-common-insts>,instcombine,libcalls-shrinkwrap,tailcallelim,simplifycfg<bonus-inst-threshold=1;no-forward-switch-cond;switch-range-to-icmp;no-switch-to-lookup;keep-loops;no-hoist-common-insts;no-sink-common-insts>,reassociate,require<opt-remark-emit>,loop-mssa(loop-instsimplify,loop-simplifycfg,licm<no-allowspeculation>,loop-rotate,licm<allowspeculation>,simple-loop-unswitch<no-nontrivial;trivial>),simplifycfg<bonus-inst-threshold=1;no-forward-switch-cond;switch-range-to-icmp;no-switch-to-lookup;keep-loops;no-hoist-common-insts;no-sink-common-insts>,instcombine,loop(loop-idiom,indvars,loop-deletion),vector-combine,mldst-motion<no-split-footer-bb>,gvn<>,sccp,bdce,instcombine,jump-threading,correlated-propagation,adce,memcpyopt,dse,loop-mssa(licm<allowspeculation>),coro-elide,simplifycfg<bonus-inst-threshold=1;no-forward-switch-cond;switch-range-to-icmp;no-switch-to-lookup;keep-loops;hoist-common-insts;sink-common-insts>,instcombine),coro-split)),deadargelim,coro-cleanup,globalopt,globaldce,elim-avail-extern,rpo-function-attrs,recompute-globalsaa,function<eager-inv>(float2int,lower-constant-intrinsics,loop(loop-rotate,loop-deletion),loop-distribute,inject-tli-mappings,loop-load-elim,instcombine,simplifycfg<bonus-inst-threshold=1;forward-switch-cond;switch-range-to-icmp;switch-to-lookup;no-keep-loops;hoist-common-insts;sink-common-insts>,vector-combine,instcombine,transform-warning,instcombine,require<opt-remark-emit>,loop-mssa(licm<allowspeculation>),alignment-from-assumptions,loop-sink,instsimplify,div-rem-pairs,tailcallelim,simplifycfg<bonus-inst-threshold=1;no-forward-switch-cond;switch-range-to-icmp;no-switch-to-lookup;keep-loops;no-hoist-common-insts;no-sink-common-insts>),globaldce,constmerge,cg-profile,rel-lookup-table-converter,function(annotation-remarks),verify";
        let (codegen, machine) = self.codegen_and_machine_mut();
        codegen
            .module()
            .run_passes(passes, machine, pass_options)
            .map_err(|e| anyhow!("Failed to run passes: {:?}", e))
    }

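    /// Clean up after Enzyme and run the full `-O3` pipeline: drop the
    /// `noinline` attribute that kept `barrier` opaque to Enzyme so it can now
    /// be inlined, and delete the leftover `preprocess_*` helper functions.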
    fn post_autodiff_optimisation(&mut self) -> Result<()> {
        // remove the noinline attribute from the barrier function, as it is only needed for enzyme
        if let Some(barrier_func) = self.codegen_mut().module().get_function("barrier") {
            let noinline_kind_id = Attribute::get_named_enum_kind_id("noinline");
            barrier_func.remove_enum_attribute(AttributeLoc::Function, noinline_kind_id);
        }

        // remove all preprocess_* functions
        for f in self.codegen_mut().module.get_functions() {
            if f.get_name().to_str().unwrap().starts_with("preprocess_") {
                unsafe { f.delete() };
            }
        }

        //self.codegen()
        //    .module()
        //    .print_to_file("post_autodiff_optimisation.ll")
        //    .unwrap();

        let passes = "default<O3>";
        let (codegen, machine) = self.codegen_and_machine_mut();
        codegen
            .module()
            .run_passes(passes, machine, PassBuilderOptions::create())
            .map_err(|e| anyhow!("Failed to run passes: {:?}", e))?;

        Ok(())
    }

    pub fn print(&self) {
        self.codegen().module().print_to_stderr();
    }
    fn codegen_mut(&mut self) -> &mut CodeGen<'static> {
        unsafe {
            self.inner
                .as_mut()
                .get_unchecked_mut()
                .codegen
                .as_mut()
                .unwrap()
        }
    }
    fn codegen_and_machine_mut(&mut self) -> (&mut CodeGen<'static>, &TargetMachine) {
        (
            unsafe {
                self.inner
                    .as_mut()
                    .get_unchecked_mut()
                    .codegen
                    .as_mut()
                    .unwrap()
            },
            &self.machine,
        )
    }

    fn codegen(&self) -> &CodeGen<'static> {
        self.inner.as_ref().get_ref().codegen.as_ref().unwrap()
    }
}

impl CodegenModule for LlvmModule {}

impl CodegenModuleCompile for LlvmModule {
    fn from_discrete_model(
        model: &DiscreteModel,
        mode: CompilerMode,
        triple: Option<Triple>,
    ) -> Result<Self> {
        let thread_dim = mode.thread_dim(model.state().nnz());
        let threaded = thread_dim > 1;

        let mut module = Self::new(triple, model, threaded)?;

        let set_u0 = module.codegen_mut().compile_set_u0(model)?;
        let _calc_stop = module.codegen_mut().compile_calc_stop(model)?;
        let rhs = module.codegen_mut().compile_rhs(model, false)?;
        let rhs_full = module.codegen_mut().compile_rhs(model, true)?;
        let mass = module.codegen_mut().compile_mass(model)?;
        let calc_out = module.codegen_mut().compile_calc_out(model, false)?;
        let calc_out_full = module.codegen_mut().compile_calc_out(model, true)?;
        let _set_id = module.codegen_mut().compile_set_id(model)?;
        let _get_dims = module.codegen_mut().compile_get_dims(model)?;
        let set_inputs = module.codegen_mut().compile_inputs(model, false)?;
        let _get_inputs = module.codegen_mut().compile_inputs(model, true)?;
        let _set_constants = module.codegen_mut().compile_set_constants(model)?;
        let tensor_info = module
            .codegen()
            .layout
            .tensors()
            .map(|(name, is_constant)| (name.to_string(), is_constant))
            .collect::<Vec<_>>();
        for (tensor, is_constant) in tensor_info {
            if is_constant {
                module
                    .codegen_mut()
                    .compile_get_constant(model, tensor.as_str())?;
            } else {
                module
                    .codegen_mut()
                    .compile_get_tensor(model, tensor.as_str())?;
            }
        }

        module.pre_autodiff_optimisation()?;

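        // Naming convention for the derivative functions generated below:
        //   *_grad   -> CompileMode::Forward
        //   *_rgrad  -> CompileMode::Reverse
        //   *_sgrad  -> CompileMode::ForwardSens
        //   *_srgrad -> CompileMode::ReverseSens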
        module.codegen_mut().compile_gradient(
            set_u0,
            &[
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::Const,
                CompileGradientArgType::Const,
            ],
            CompileMode::Forward,
            "set_u0_grad",
        )?;

        module.codegen_mut().compile_gradient(
            rhs,
            &[
                CompileGradientArgType::Const,
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::Const,
                CompileGradientArgType::Const,
            ],
            CompileMode::Forward,
            "rhs_grad",
        )?;

        module.codegen_mut().compile_gradient(
            calc_out,
            &[
                CompileGradientArgType::Const,
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::Const,
                CompileGradientArgType::Const,
            ],
            CompileMode::Forward,
            "calc_out_grad",
        )?;
        module.codegen_mut().compile_gradient(
            set_inputs,
            &[
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::DupNoNeed,
            ],
            CompileMode::Forward,
            "set_inputs_grad",
        )?;

        module.codegen_mut().compile_gradient(
            set_u0,
            &[
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::Const,
                CompileGradientArgType::Const,
            ],
            CompileMode::Reverse,
            "set_u0_rgrad",
        )?;

        module.codegen_mut().compile_gradient(
            mass,
            &[
                CompileGradientArgType::Const,
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::Const,
                CompileGradientArgType::Const,
            ],
            CompileMode::Reverse,
            "mass_rgrad",
        )?;

        module.codegen_mut().compile_gradient(
            rhs,
            &[
                CompileGradientArgType::Const,
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::Const,
                CompileGradientArgType::Const,
            ],
            CompileMode::Reverse,
            "rhs_rgrad",
        )?;
        module.codegen_mut().compile_gradient(
            calc_out,
            &[
                CompileGradientArgType::Const,
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::Const,
                CompileGradientArgType::Const,
            ],
            CompileMode::Reverse,
            "calc_out_rgrad",
        )?;

        module.codegen_mut().compile_gradient(
            set_inputs,
            &[
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::DupNoNeed,
            ],
            CompileMode::Reverse,
            "set_inputs_rgrad",
        )?;

        module.codegen_mut().compile_gradient(
            rhs_full,
            &[
                CompileGradientArgType::Const,
                CompileGradientArgType::Const,
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::Const,
                CompileGradientArgType::Const,
            ],
            CompileMode::ForwardSens,
            "rhs_sgrad",
        )?;

        module.codegen_mut().compile_gradient(
            calc_out_full,
            &[
                CompileGradientArgType::Const,
                CompileGradientArgType::Const,
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::Const,
                CompileGradientArgType::Const,
            ],
            CompileMode::ForwardSens,
            "calc_out_sgrad",
        )?;
        module.codegen_mut().compile_gradient(
            calc_out_full,
            &[
                CompileGradientArgType::Const,
                CompileGradientArgType::Const,
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::Const,
                CompileGradientArgType::Const,
            ],
            CompileMode::ReverseSens,
            "calc_out_srgrad",
        )?;

        module.codegen_mut().compile_gradient(
            rhs_full,
            &[
                CompileGradientArgType::Const,
                CompileGradientArgType::Const,
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::Const,
                CompileGradientArgType::Const,
            ],
            CompileMode::ReverseSens,
            "rhs_srgrad",
        )?;

        module.post_autodiff_optimisation()?;
        Ok(module)
    }
}

impl CodegenModuleEmit for LlvmModule {
    fn to_object(self) -> Result<Vec<u8>> {
        let module = self.codegen().module();
        //module.print_to_stderr();
        let buffer = self
            .machine
            .write_to_memory_buffer(module, FileType::Object)
            .unwrap()
            .as_slice()
            .to_vec();
        Ok(buffer)
    }
}
impl CodegenModuleJit for LlvmModule {
    fn jit(&mut self) -> Result<HashMap<String, *const u8>> {
        let ee = self
            .codegen()
            .module()
            .create_jit_execution_engine(OptimizationLevel::Aggressive)
            .map_err(|e| anyhow!("Failed to create JIT execution engine: {:?}", e))?;

        let module = self.codegen().module();
        let mut symbols = HashMap::new();
        for function in module.get_functions() {
            let name = function.get_name().to_str().unwrap();
            let address = ee.get_function_address(name);
            if let Ok(address) = address {
                symbols.insert(name.to_string(), address as *const u8);
            }
        }
        Ok(symbols)
    }
}

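/// Module-level globals shared by the generated functions. The `enzyme_const_`
/// name prefix is load-bearing: as the note in `new` below explains, Enzyme
/// treats globals with `enzyme_const` in the name as inactive, which the
/// `enzyme_inactive` metadata failed to achieve.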
struct Globals<'ctx> {
    indices: Option<GlobalValue<'ctx>>,
    constants: Option<GlobalValue<'ctx>>,
    thread_counter: Option<GlobalValue<'ctx>>,
}

impl<'ctx> Globals<'ctx> {
    fn new(
        layout: &DataLayout,
        module: &Module<'ctx>,
        int_type: IntType<'ctx>,
        real_type: FloatType<'ctx>,
        threaded: bool,
    ) -> Self {
        let thread_counter = if threaded {
            let tc = module.add_global(
                int_type,
                Some(AddressSpace::default()),
                "enzyme_const_thread_counter",
            );
            // todo: for some reason this doesn't make enzyme think it's inactive
            // but using enzyme_const in the name does
            // todo: also, adding this metadata causes the print of the module to segfault,
            // so maybe a bug in inkwell
            //let md_string = context.metadata_string("enzyme_inactive");
            //tc.set_metadata(md_string, 0);
            let tc_value = int_type.const_zero();
            tc.set_visibility(GlobalVisibility::Hidden);
            tc.set_initializer(&tc_value.as_basic_value_enum());
            Some(tc)
        } else {
            None
        };
        let constants = if layout.constants().is_empty() {
            None
        } else {
            let constants_array_type =
                real_type.array_type(u32::try_from(layout.constants().len()).unwrap());
            let constants = module.add_global(
                constants_array_type,
                Some(AddressSpace::default()),
                "enzyme_const_constants",
            );
            constants.set_visibility(GlobalVisibility::Hidden);
            constants.set_constant(false);
            constants.set_initializer(&constants_array_type.const_zero());
            Some(constants)
        };
        let indices = if layout.indices().is_empty() {
            None
        } else {
            let indices_array_type =
                int_type.array_type(u32::try_from(layout.indices().len()).unwrap());
            let indices_array_values = layout
                .indices()
                .iter()
                .map(|&i| int_type.const_int(i.try_into().unwrap(), false))
                .collect::<Vec<IntValue>>();
            let indices_value = int_type.const_array(indices_array_values.as_slice());
            let indices = module.add_global(
                indices_array_type,
                Some(AddressSpace::default()),
                "enzyme_const_indices",
            );
            indices.set_constant(true);
            indices.set_visibility(GlobalVisibility::Hidden);
            indices.set_initializer(&indices_value);
            Some(indices)
        };
        Self {
            indices,
            thread_counter,
            constants,
        }
    }
}

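/// Activity annotation for each argument of a function handed to Enzyme,
/// presumably lowered to the `CDIFFE_TYPE` constants imported above
/// (`Const` -> DFT_CONSTANT, `Dup` -> DFT_DUP_ARG, `DupNoNeed` -> DFT_DUP_NONEED).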
pub enum CompileGradientArgType {
    Const,
    Dup,
    DupNoNeed,
}

pub enum CompileMode {
    Forward,
    ForwardSens,
    Reverse,
    ReverseSens,
}

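/// Per-model code generator state. `variables` maps tensor, block, and
/// parameter names to pointers into the data/constants arrays, `functions`
/// caches generated helpers (barriers, custom math functions), and
/// `fn_value_opt`/`tensor_ptr_opt` track the function and tensor output
/// currently being compiled.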
pub struct CodeGen<'ctx> {
    context: &'ctx inkwell::context::Context,
    module: Module<'ctx>,
    builder: Builder<'ctx>,
    variables: HashMap<String, PointerValue<'ctx>>,
    functions: HashMap<String, FunctionValue<'ctx>>,
    fn_value_opt: Option<FunctionValue<'ctx>>,
    tensor_ptr_opt: Option<PointerValue<'ctx>>,
    real_type: FloatType<'ctx>,
    real_ptr_type: PointerType<'ctx>,
    real_type_str: String,
    int_type: IntType<'ctx>,
    int_ptr_type: PointerType<'ctx>,
    layout: DataLayout,
    globals: Globals<'ctx>,
    threaded: bool,
    _ee: Option<ExecutionEngine<'ctx>>,
}

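// Custom Enzyme call handlers for the generated `barrier` function (registered
// via `EnzymeRegisterCallHandler`, imported above). The forward handler returns
// 1 without rewriting the call; the reverse handler replaces the barrier call
// in the generated reverse pass with a call to `barrier_grad`, forwarding the
// original (barrier_num, total_barriers, thread_count) arguments mapped into
// the new function via `EnzymeGradientUtilsNewFromOriginal`.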
unsafe extern "C" fn fwd_handler(
    _builder: LLVMBuilderRef,
    _call_instruction: LLVMValueRef,
    _gutils: *mut GradientUtils,
    _dcall: *mut LLVMValueRef,
    _normal_return: *mut LLVMValueRef,
    _shadow_return: *mut LLVMValueRef,
) -> u8 {
    1
}

unsafe extern "C" fn rev_handler(
    builder: LLVMBuilderRef,
    call_instruction: LLVMValueRef,
    gutils: *mut DiffeGradientUtils,
    _tape: LLVMValueRef,
) {
    let call_block = LLVMGetInstructionParent(call_instruction);
    let call_function = LLVMGetBasicBlockParent(call_block);
    let module = LLVMGetGlobalParent(call_function);
    let name_c_str = CString::new("barrier_grad").unwrap();
    let barrier_func = LLVMGetNamedFunction(module, name_c_str.as_ptr());
    let barrier_func_type = LLVMGlobalGetValueType(barrier_func);
    let barrier_num = LLVMGetArgOperand(call_instruction, 0);
    let total_barriers = LLVMGetArgOperand(call_instruction, 1);
    let thread_count = LLVMGetArgOperand(call_instruction, 2);
    let barrier_num = EnzymeGradientUtilsNewFromOriginal(gutils as *mut GradientUtils, barrier_num);
    let total_barriers =
        EnzymeGradientUtilsNewFromOriginal(gutils as *mut GradientUtils, total_barriers);
    let thread_count =
        EnzymeGradientUtilsNewFromOriginal(gutils as *mut GradientUtils, thread_count);
    let mut args = [barrier_num, total_barriers, thread_count];
    let name_c_str = CString::new("").unwrap();
    LLVMBuildCall2(
        builder,
        barrier_func_type,
        barrier_func,
        args.as_mut_ptr(),
        args.len() as u32,
        name_c_str.as_ptr(),
    );
}

#[allow(dead_code)]
enum PrintValue<'ctx> {
    Real(FloatValue<'ctx>),
    Int(IntValue<'ctx>),
}

impl<'ctx> CodeGen<'ctx> {
    pub fn new(
        model: &DiscreteModel,
        context: &'ctx inkwell::context::Context,
        real_type: FloatType<'ctx>,
        int_type: IntType<'ctx>,
        real_type_str: &str,
        threaded: bool,
    ) -> Result<Self> {
        let builder = context.create_builder();
        let layout = DataLayout::new(model);
        let module = context.create_module(model.name());
        let globals = Globals::new(&layout, &module, int_type, real_type, threaded);
        let real_ptr_type = Self::pointer_type(context, real_type.into());
        let int_ptr_type = Self::pointer_type(context, int_type.into());
        let mut ret = Self {
            context,
            module,
            builder,
            real_type,
            real_ptr_type,
            real_type_str: real_type_str.to_owned(),
            variables: HashMap::new(),
            functions: HashMap::new(),
            fn_value_opt: None,
            tensor_ptr_opt: None,
            layout,
            int_type,
            int_ptr_type,
            globals,
            threaded,
            _ee: None,
        };
        if threaded {
            ret.compile_barrier_init()?;
            ret.compile_barrier()?;
            ret.compile_barrier_grad()?;
            // todo: think I can remove this unless I want to call enzyme using an LLVM pass
            //ret.globals.add_registered_barrier(ret.context, &ret.module);
        }
        Ok(ret)
    }

    #[allow(dead_code)]
    fn compile_print_value(
        &mut self,
        name: &str,
        value: PrintValue<'ctx>,
    ) -> Result<CallSiteValue> {
        let void_type = self.context.void_type();
        // int printf(const char *format, ...)
        let printf_type = void_type.fn_type(&[self.int_ptr_type.into()], true);
        // get printf function or declare it if it doesn't exist
        let printf = match self.module.get_function("printf") {
            Some(f) => f,
            None => self
                .module
                .add_function("printf", printf_type, Some(Linkage::External)),
        };
        let (format_str, format_str_name) = match value {
            PrintValue::Real(_) => (format!("{name}: %f\n"), format!("real_format_{name}")),
            PrintValue::Int(_) => (format!("{name}: %d\n"), format!("int_format_{name}")),
        };
        // change format_str to c string
        let format_str = CString::new(format_str).unwrap();
        // if format_str_name does not already exist as a global, add it
        let format_str_global = match self.module.get_global(format_str_name.as_str()) {
            Some(g) => g,
            None => {
                let format_str = self.context.const_string(format_str.as_bytes(), true);
                let fmt_str =
                    self.module
                        .add_global(format_str.get_type(), None, format_str_name.as_str());
                fmt_str.set_initializer(&format_str);
                fmt_str.set_visibility(GlobalVisibility::Hidden);
                fmt_str
            }
        };
        // call printf with the format string and the value
        let format_str_ptr = self.builder.build_pointer_cast(
            format_str_global.as_pointer_value(),
            self.int_ptr_type,
            "format_str_ptr",
        )?;
        let value: BasicMetadataValueEnum = match value {
            PrintValue::Real(v) => v.into(),
            PrintValue::Int(v) => v.into(),
        };
        self.builder
            .build_call(printf, &[format_str_ptr.into(), value], "printf_call")
            .map_err(|e| anyhow!("Error building call to printf: {}", e))
    }

    fn compile_set_constants(&mut self, model: &DiscreteModel) -> Result<FunctionValue<'ctx>> {
        self.clear();
        let void_type = self.context.void_type();
        let fn_type = void_type.fn_type(&[self.int_type.into(), self.int_type.into()], false);
        let fn_arg_names = &["thread_id", "thread_dim"];
        let function = self.module.add_function("set_constants", fn_type, None);

        let basic_block = self.context.append_basic_block(function, "entry");
        self.fn_value_opt = Some(function);
        self.builder.position_at_end(basic_block);

        for (i, arg) in function.get_param_iter().enumerate() {
            let name = fn_arg_names[i];
            let alloca = self.function_arg_alloca(name, arg);
            self.insert_param(name, alloca);
        }

        self.insert_indices();
        self.insert_constants(model);

        let mut nbarriers = 0;
        let total_barriers = (model.constant_defns().len()) as u64;
        #[allow(clippy::explicit_counter_loop)]
        for a in model.constant_defns() {
            self.jit_compile_tensor(a, Some(*self.get_var(a)))?;
            self.jit_compile_call_barrier(nbarriers, total_barriers);
            nbarriers += 1;
        }

        self.builder.build_return(None)?;

        if function.verify(true) {
            Ok(function)
        } else {
            function.print_to_stderr();
            unsafe {
                function.delete();
            }
            Err(anyhow!("Invalid generated function."))
        }
    }

    fn compile_barrier_init(&mut self) -> Result<FunctionValue<'ctx>> {
        self.clear();
        let void_type = self.context.void_type();
        let fn_type = void_type.fn_type(&[], false);
        let function = self.module.add_function("barrier_init", fn_type, None);

        let entry_block = self.context.append_basic_block(function, "entry");

        self.fn_value_opt = Some(function);
        self.builder.position_at_end(entry_block);

        let thread_counter = self.globals.thread_counter.unwrap().as_pointer_value();
        self.builder
            .build_store(thread_counter, self.int_type.const_zero())?;

        self.builder.build_return(None)?;

        if function.verify(true) {
            self.functions.insert("barrier_init".to_owned(), function);
            Ok(function)
        } else {
            function.print_to_stderr();
            unsafe {
                function.delete();
            }
            Err(anyhow!("Invalid generated function."))
        }
    }

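    /// Generate `barrier(barrier_num, total_barriers, thread_count)`: a spin
    /// barrier on the shared `thread_counter` global. Each thread atomically
    /// increments the counter, then spins until it reaches
    /// `barrier_num * thread_count`, i.e. until every thread has passed this
    /// barrier. A call with `barrier_num == total_barriers` is a no-op (it
    /// branches straight to `barrier_done`). Marked `noinline` so Enzyme sees a
    /// single call site; the attribute is removed again in
    /// `post_autodiff_optimisation`.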
    fn compile_barrier(&mut self) -> Result<FunctionValue<'ctx>> {
        self.clear();
        let void_type = self.context.void_type();
        let fn_type = void_type.fn_type(
            &[
                self.int_type.into(),
                self.int_type.into(),
                self.int_type.into(),
            ],
            false,
        );
        let function = self.module.add_function("barrier", fn_type, None);
        let noinline_kind_id = Attribute::get_named_enum_kind_id("noinline");
        let noinline = self.context.create_enum_attribute(noinline_kind_id, 0);
        function.add_attribute(AttributeLoc::Function, noinline);
813
814        let entry_block = self.context.append_basic_block(function, "entry");
815        let increment_block = self.context.append_basic_block(function, "increment");
816        let wait_loop_block = self.context.append_basic_block(function, "wait_loop");
817        let barrier_done_block = self.context.append_basic_block(function, "barrier_done");
818
819        self.fn_value_opt = Some(function);
820        self.builder.position_at_end(entry_block);
821
822        let thread_counter = self.globals.thread_counter.unwrap().as_pointer_value();
823        let barrier_num = function.get_nth_param(0).unwrap().into_int_value();
824        let total_barriers = function.get_nth_param(1).unwrap().into_int_value();
825        let thread_count = function.get_nth_param(2).unwrap().into_int_value();
826
827        let nbarrier_equals_total_barriers = self
828            .builder
829            .build_int_compare(
830                IntPredicate::EQ,
831                barrier_num,
832                total_barriers,
833                "nbarrier_equals_total_barriers",
834            )
835            .unwrap();
836        // branch to barrier_done if nbarrier == total_barriers
837        self.builder.build_conditional_branch(
838            nbarrier_equals_total_barriers,
839            barrier_done_block,
840            increment_block,
841        )?;
842        self.builder.position_at_end(increment_block);
843
844        let barrier_num_times_thread_count = self
845            .builder
846            .build_int_mul(barrier_num, thread_count, "barrier_num_times_thread_count")
847            .unwrap();
848
849        // Atomically increment the barrier counter
850        let i32_type = self.context.i32_type();
851        let one = i32_type.const_int(1, false);
852        self.builder.build_atomicrmw(
853            AtomicRMWBinOp::Add,
854            thread_counter,
855            one,
856            AtomicOrdering::Monotonic,
857        )?;
858
859        // wait_loop:
860        self.builder.build_unconditional_branch(wait_loop_block)?;
861        self.builder.position_at_end(wait_loop_block);
862
863        let current_value = self
864            .builder
865            .build_load(i32_type, thread_counter, "current_value")?
866            .into_int_value();
867
868        current_value
869            .as_instruction_value()
870            .unwrap()
871            .set_atomic_ordering(AtomicOrdering::Monotonic)
872            .map_err(|e| anyhow!("Error setting atomic ordering: {:?}", e))?;
873
874        let all_threads_done = self.builder.build_int_compare(
875            IntPredicate::UGE,
876            current_value,
877            barrier_num_times_thread_count,
878            "all_threads_done",
879        )?;
880
881        self.builder.build_conditional_branch(
882            all_threads_done,
883            barrier_done_block,
884            wait_loop_block,
885        )?;
886        self.builder.position_at_end(barrier_done_block);
887
888        self.builder.build_return(None)?;
889
890        if function.verify(true) {
891            self.functions.insert("barrier".to_owned(), function);
892            Ok(function)
893        } else {
894            function.print_to_stderr();
895            unsafe {
896                function.delete();
897            }
898            Err(anyhow!("Invalid generated function."))
899        }
900    }
901
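    /// Reverse-pass counterpart of `barrier`, called from `rev_handler`. The
    /// shared counter is never reset between the forward and reverse passes;
    /// barriers are revisited in reverse order, so the threshold to wait for
    /// keeps growing: `(2 * total_barriers - barrier_num) * thread_count`.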
    fn compile_barrier_grad(&mut self) -> Result<FunctionValue<'ctx>> {
        self.clear();
        let void_type = self.context.void_type();
        let fn_type = void_type.fn_type(
            &[
                self.int_type.into(),
                self.int_type.into(),
                self.int_type.into(),
            ],
            false,
        );
        let function = self.module.add_function("barrier_grad", fn_type, None);

        let entry_block = self.context.append_basic_block(function, "entry");
        let wait_loop_block = self.context.append_basic_block(function, "wait_loop");
        let barrier_done_block = self.context.append_basic_block(function, "barrier_done");

        self.fn_value_opt = Some(function);
        self.builder.position_at_end(entry_block);

        let thread_counter = self.globals.thread_counter.unwrap().as_pointer_value();
        let barrier_num = function.get_nth_param(0).unwrap().into_int_value();
        let total_barriers = function.get_nth_param(1).unwrap().into_int_value();
        let thread_count = function.get_nth_param(2).unwrap().into_int_value();

        let twice_total_barriers = self
            .builder
            .build_int_mul(
                total_barriers,
                self.int_type.const_int(2, false),
                "twice_total_barriers",
            )
            .unwrap();
        let twice_total_barriers_minus_barrier_num = self
            .builder
            .build_int_sub(
                twice_total_barriers,
                barrier_num,
                "twice_total_barriers_minus_barrier_num",
            )
            .unwrap();
        let twice_total_barriers_minus_barrier_num_times_thread_count = self
            .builder
            .build_int_mul(
                twice_total_barriers_minus_barrier_num,
                thread_count,
                "twice_total_barriers_minus_barrier_num_times_thread_count",
            )
            .unwrap();

        // Atomically increment the barrier counter
        let i32_type = self.context.i32_type();
        let one = i32_type.const_int(1, false);
        self.builder.build_atomicrmw(
            AtomicRMWBinOp::Add,
            thread_counter,
            one,
            AtomicOrdering::Monotonic,
        )?;

        // wait_loop:
        self.builder.build_unconditional_branch(wait_loop_block)?;
        self.builder.position_at_end(wait_loop_block);

        let current_value = self
            .builder
            .build_load(i32_type, thread_counter, "current_value")?
            .into_int_value();
        current_value
            .as_instruction_value()
            .unwrap()
            .set_atomic_ordering(AtomicOrdering::Monotonic)
            .map_err(|e| anyhow!("Error setting atomic ordering: {:?}", e))?;

        let all_threads_done = self.builder.build_int_compare(
            IntPredicate::UGE,
            current_value,
            twice_total_barriers_minus_barrier_num_times_thread_count,
            "all_threads_done",
        )?;

        self.builder.build_conditional_branch(
            all_threads_done,
            barrier_done_block,
            wait_loop_block,
        )?;
        self.builder.position_at_end(barrier_done_block);

        self.builder.build_return(None)?;

        if function.verify(true) {
            self.functions.insert("barrier_grad".to_owned(), function);
            Ok(function)
        } else {
            function.print_to_stderr();
            unsafe {
                function.delete();
            }
            Err(anyhow!("Invalid generated function."))
        }
    }

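    /// Emit a call to `barrier` at the current insert point (a no-op when not
    /// threaded). Barriers are numbered from 1, hence `nbarrier + 1` below.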
    fn jit_compile_call_barrier(&mut self, nbarrier: u64, total_barriers: u64) {
        if !self.threaded {
            return;
        }
        let thread_dim = self.get_param("thread_dim");
        let thread_dim = self
            .builder
            .build_load(self.int_type, *thread_dim, "thread_dim")
            .unwrap()
            .into_int_value();
        let nbarrier = self.int_type.const_int(nbarrier + 1, false);
        let total_barriers = self.int_type.const_int(total_barriers, false);
        let barrier = self.get_function("barrier").unwrap();
        self.builder
            .build_call(
                barrier,
                &[
                    BasicMetadataValueEnum::IntValue(nbarrier),
                    BasicMetadataValueEnum::IntValue(total_barriers),
                    BasicMetadataValueEnum::IntValue(thread_dim),
                ],
                "barrier",
            )
            .unwrap();
    }

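    /// Split `size` work items across `thread_dim` threads: thread `i` gets the
    /// half-open range `[i * size / thread_dim, (i + 1) * size / thread_dim)`,
    /// which covers every item with chunk sizes differing by at most one. For
    /// example, size = 10 and thread_dim = 4 gives [0,2), [2,5), [5,7), [7,10).
    /// Also opens a new "threading_block" for the per-thread body; the returned
    /// blocks are later wired up by `jit_end_threading`, which skips the body
    /// entirely when the range is empty.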
    fn jit_threading_limits(
        &mut self,
        size: IntValue<'ctx>,
    ) -> Result<(
        IntValue<'ctx>,
        IntValue<'ctx>,
        BasicBlock<'ctx>,
        BasicBlock<'ctx>,
    )> {
        let one = self.int_type.const_int(1, false);
        let thread_id = self.get_param("thread_id");
        let thread_id = self
            .builder
            .build_load(self.int_type, *thread_id, "thread_id")
            .unwrap()
            .into_int_value();
        let thread_dim = self.get_param("thread_dim");
        let thread_dim = self
            .builder
            .build_load(self.int_type, *thread_dim, "thread_dim")
            .unwrap()
            .into_int_value();

        // start index is i * size / thread_dim
        let i_times_size = self
            .builder
            .build_int_mul(thread_id, size, "i_times_size")?;
        let start = self
            .builder
            .build_int_unsigned_div(i_times_size, thread_dim, "start")?;

        // the ending index for thread i is (i+1) * size / thread_dim
        let i_plus_one = self.builder.build_int_add(thread_id, one, "i_plus_one")?;
        let i_plus_one_times_size =
            self.builder
                .build_int_mul(i_plus_one, size, "i_plus_one_times_size")?;
        let end = self
            .builder
            .build_int_unsigned_div(i_plus_one_times_size, thread_dim, "end")?;

        let test_done = self.builder.get_insert_block().unwrap();
        let next_block = self
            .context
            .append_basic_block(self.fn_value_opt.unwrap(), "threading_block");
        self.builder.position_at_end(next_block);

        Ok((start, end, test_done, next_block))
    }

    fn jit_end_threading(
        &mut self,
        start: IntValue<'ctx>,
        end: IntValue<'ctx>,
        test_done: BasicBlock<'ctx>,
        next: BasicBlock<'ctx>,
    ) -> Result<()> {
        let exit = self
            .context
            .append_basic_block(self.fn_value_opt.unwrap(), "exit");
        self.builder.build_unconditional_branch(exit)?;
        self.builder.position_at_end(test_done);
        // done if start == end
        let done = self
            .builder
            .build_int_compare(IntPredicate::EQ, start, end, "done")?;
        self.builder.build_conditional_branch(done, exit, next)?;
        self.builder.position_at_end(exit);
        Ok(())
    }

    pub fn write_bitcode_to_path(&self, path: &std::path::Path) {
        self.module.write_bitcode_to_path(path);
    }

    fn insert_constants(&mut self, model: &DiscreteModel) {
        if let Some(constants) = self.globals.constants.as_ref() {
            self.insert_param("constants", constants.as_pointer_value());
            for tensor in model.constant_defns() {
                self.insert_tensor(tensor, true);
            }
        }
    }

    fn insert_data(&mut self, model: &DiscreteModel) {
        self.insert_constants(model);

        for tensor in model.inputs() {
            self.insert_tensor(tensor, false);
        }
        for tensor in model.input_dep_defns() {
            self.insert_tensor(tensor, false);
        }
        for tensor in model.time_dep_defns() {
            self.insert_tensor(tensor, false);
        }
        for tensor in model.state_dep_defns() {
            self.insert_tensor(tensor, false);
        }
    }

    fn pointer_type(context: &'ctx Context, _ty: BasicTypeEnum<'ctx>) -> PointerType<'ctx> {
        context.ptr_type(AddressSpace::default())
    }

    fn fn_pointer_type(context: &'ctx Context, _ty: FunctionType<'ctx>) -> PointerType<'ctx> {
        context.ptr_type(AddressSpace::default())
    }

    fn insert_indices(&mut self) {
        if let Some(indices) = self.globals.indices.as_ref() {
            let i32_type = self.context.i32_type();
            let zero = i32_type.const_int(0, false);
            let ptr = unsafe {
                indices
                    .as_pointer_value()
                    .const_in_bounds_gep(i32_type, &[zero])
            };
            self.variables.insert("indices".to_owned(), ptr);
        }
    }

    fn insert_param(&mut self, name: &str, value: PointerValue<'ctx>) {
        self.variables.insert(name.to_owned(), value);
    }

    fn build_gep<T: BasicType<'ctx>>(
        &self,
        ty: T,
        ptr: PointerValue<'ctx>,
        ordered_indexes: &[IntValue<'ctx>],
        name: &str,
    ) -> Result<PointerValue<'ctx>> {
        unsafe {
            self.builder
                .build_gep(ty, ptr, ordered_indexes, name)
                .map_err(|e| e.into())
        }
    }

    fn build_load<T: BasicType<'ctx>>(
        &self,
        ty: T,
        ptr: PointerValue<'ctx>,
        name: &str,
    ) -> Result<BasicValueEnum<'ctx>> {
        self.builder.build_load(ty, ptr, name).map_err(|e| e.into())
    }

    fn get_ptr_to_index<T: BasicType<'ctx>>(
        builder: &Builder<'ctx>,
        ty: T,
        ptr: &PointerValue<'ctx>,
        index: IntValue<'ctx>,
        name: &str,
    ) -> PointerValue<'ctx> {
        unsafe {
            builder
                .build_in_bounds_gep(ty, *ptr, &[index], name)
                .unwrap()
        }
    }

    fn insert_state(&mut self, u: &Tensor) {
        let mut data_index = 0;
        for blk in u.elmts() {
            if let Some(name) = blk.name() {
                let ptr = self.variables.get("u").unwrap();
                let i = self
                    .context
                    .i32_type()
                    .const_int(data_index.try_into().unwrap(), false);
                let alloca = Self::get_ptr_to_index(
                    &self.create_entry_block_builder(),
                    self.real_type,
                    ptr,
                    i,
                    blk.name().unwrap(),
                );
                self.variables.insert(name.to_owned(), alloca);
            }
            data_index += blk.nnz();
        }
    }
    fn insert_dot_state(&mut self, dudt: &Tensor) {
        let mut data_index = 0;
        for blk in dudt.elmts() {
            if let Some(name) = blk.name() {
                let ptr = self.variables.get("dudt").unwrap();
                let i = self
                    .context
                    .i32_type()
                    .const_int(data_index.try_into().unwrap(), false);
                let alloca = Self::get_ptr_to_index(
                    &self.create_entry_block_builder(),
                    self.real_type,
                    ptr,
                    i,
                    blk.name().unwrap(),
                );
                self.variables.insert(name.to_owned(), alloca);
            }
            data_index += blk.nnz();
        }
    }
    fn insert_tensor(&mut self, tensor: &Tensor, is_constant: bool) {
        let var_name = if is_constant { "constants" } else { "data" };
        let ptr = *self.variables.get(var_name).unwrap();
        let mut data_index = self.layout.get_data_index(tensor.name()).unwrap();
        let i = self
            .context
            .i32_type()
            .const_int(data_index.try_into().unwrap(), false);
        let alloca = Self::get_ptr_to_index(
            &self.create_entry_block_builder(),
            self.real_type,
            &ptr,
            i,
            tensor.name(),
        );
        self.variables.insert(tensor.name().to_owned(), alloca);

        // insert any named blocks
        for blk in tensor.elmts() {
            if let Some(name) = blk.name() {
                let i = self
                    .context
                    .i32_type()
                    .const_int(data_index.try_into().unwrap(), false);
                let alloca = Self::get_ptr_to_index(
                    &self.create_entry_block_builder(),
                    self.real_type,
                    &ptr,
                    i,
                    name,
                );
                self.variables.insert(name.to_owned(), alloca);
            }
            // named blocks only supported for rank <= 1, so we can just add the nnz to get the next data index
            data_index += blk.nnz();
        }
    }
    fn get_param(&self, name: &str) -> &PointerValue<'ctx> {
        self.variables.get(name).unwrap()
    }

    fn get_var(&self, tensor: &Tensor) -> &PointerValue<'ctx> {
        self.variables.get(tensor.name()).unwrap()
    }

    fn get_function(&mut self, name: &str) -> Option<FunctionValue<'ctx>> {
        match self.functions.get(name) {
            Some(&func) => Some(func),
            None => {
                let function = match name {
                    // support some llvm intrinsics
                    "sin" | "cos" | "tan" | "exp" | "log" | "log10" | "sqrt" | "abs"
                    | "copysign" | "pow" | "min" | "max" => {
                        let arg_len = 1;
                        let intrinsic_name = match name {
                            "min" => "minnum",
                            "max" => "maxnum",
                            "abs" => "fabs",
                            _ => name,
                        };
                        let llvm_name = format!("llvm.{}.{}", intrinsic_name, self.real_type_str);
                        let intrinsic = Intrinsic::find(&llvm_name).unwrap();
                        let ret_type = self.real_type;

                        let args_types = std::iter::repeat_n(ret_type, arg_len)
                            .map(|f| f.into())
                            .collect::<Vec<BasicTypeEnum>>();
                        // if we get an intrinsic, we don't need to add to the list of functions and can return early
                        return intrinsic.get_declaration(&self.module, args_types.as_slice());
                    }
                    // some custom functions
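                    // sigmoid(x) = 1 / (1 + exp(-x)), built inline from the exp intrinsic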
                    "sigmoid" => {
                        let arg_len = 1;
                        let ret_type = self.real_type;

                        let args_types = std::iter::repeat_n(ret_type, arg_len)
                            .map(|f| f.into())
                            .collect::<Vec<BasicMetadataTypeEnum>>();

                        let fn_type = ret_type.fn_type(args_types.as_slice(), false);
                        let fn_val = self.module.add_function(name, fn_type, None);

                        for arg in fn_val.get_param_iter() {
                            arg.into_float_value().set_name("x");
                        }

                        let current_block = self.builder.get_insert_block().unwrap();
                        let basic_block = self.context.append_basic_block(fn_val, "entry");
                        self.builder.position_at_end(basic_block);
                        let x = fn_val.get_nth_param(0)?.into_float_value();
                        let one = self.real_type.const_float(1.0);
                        let negx = self.builder.build_float_neg(x, name).ok()?;
                        let exp = self.get_function("exp").unwrap();
                        let exp_negx = self
                            .builder
                            .build_call(exp, &[BasicMetadataValueEnum::FloatValue(negx)], name)
                            .ok()?;
                        let one_plus_exp_negx = self
                            .builder
                            .build_float_add(
                                exp_negx
                                    .try_as_basic_value()
                                    .left()
                                    .unwrap()
                                    .into_float_value(),
                                one,
                                name,
                            )
                            .ok()?;
                        let sigmoid = self
                            .builder
                            .build_float_div(one, one_plus_exp_negx, name)
                            .ok()?;
                        self.builder.build_return(Some(&sigmoid)).ok();
                        self.builder.position_at_end(current_block);
                        Some(fn_val)
                    }
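                    // arcsinh(x) = ln(x + sqrt(x^2 + 1)); arccosh(x) = ln(x + sqrt(x^2 - 1));
                    // the sign of `one` below selects between the two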
1351                    "arcsinh" | "arccosh" => {
1352                        let arg_len = 1;
1353                        let ret_type = self.real_type;
1354
1355                        let args_types = std::iter::repeat_n(ret_type, arg_len)
1356                            .map(|f| f.into())
1357                            .collect::<Vec<BasicMetadataTypeEnum>>();
1358
1359                        let fn_type = ret_type.fn_type(args_types.as_slice(), false);
1360                        let fn_val = self.module.add_function(name, fn_type, None);
1361
1362                        for arg in fn_val.get_param_iter() {
1363                            arg.into_float_value().set_name("x");
1364                        }
1365
1366                        let current_block = self.builder.get_insert_block().unwrap();
1367                        let basic_block = self.context.append_basic_block(fn_val, "entry");
1368                        self.builder.position_at_end(basic_block);
1369                        let x = fn_val.get_nth_param(0)?.into_float_value();
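                        // arcsinh(x) = ln(x + sqrt(x^2 + 1)) and
                        // arccosh(x) = ln(x + sqrt(x^2 - 1)), so the constant
                        // added under the sqrt is +1.0 or -1.0 respectively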
                        let one = match name {
                            "arccosh" => self.real_type.const_float(-1.0),
                            "arcsinh" => self.real_type.const_float(1.0),
                            _ => panic!("unknown function"),
                        };
                        let x_squared = self.builder.build_float_mul(x, x, name).ok()?;
                        let one_plus_x_squared =
                            self.builder.build_float_add(x_squared, one, name).ok()?;
                        let sqrt = self.get_function("sqrt").unwrap();
                        let sqrt_one_plus_x_squared = self
                            .builder
                            .build_call(
                                sqrt,
                                &[BasicMetadataValueEnum::FloatValue(one_plus_x_squared)],
                                name,
                            )
                            .unwrap()
                            .try_as_basic_value()
                            .left()
                            .unwrap()
                            .into_float_value();
                        let x_plus_sqrt_one_plus_x_squared = self
                            .builder
                            .build_float_add(x, sqrt_one_plus_x_squared, name)
                            .ok()?;
                        let ln = self.get_function("log").unwrap();
                        let result = self
                            .builder
                            .build_call(
                                ln,
                                &[BasicMetadataValueEnum::FloatValue(
                                    x_plus_sqrt_one_plus_x_squared,
                                )],
                                name,
                            )
                            .unwrap()
                            .try_as_basic_value()
                            .left()
                            .unwrap()
                            .into_float_value();
                        self.builder.build_return(Some(&result)).ok();
                        self.builder.position_at_end(current_block);
                        Some(fn_val)
                    }
                    "heaviside" => {
                        let arg_len = 1;
                        let ret_type = self.real_type;

                        let args_types = std::iter::repeat_n(ret_type, arg_len)
                            .map(|f| f.into())
                            .collect::<Vec<BasicMetadataTypeEnum>>();

                        let fn_type = ret_type.fn_type(args_types.as_slice(), false);
                        let fn_val = self.module.add_function(name, fn_type, None);

                        for arg in fn_val.get_param_iter() {
                            arg.into_float_value().set_name("x");
                        }

                        let current_block = self.builder.get_insert_block().unwrap();
                        let basic_block = self.context.append_basic_block(fn_val, "entry");
                        self.builder.position_at_end(basic_block);
                        let x = fn_val.get_nth_param(0)?.into_float_value();
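                        // heaviside(x) = 1.0 if x >= 0.0, else 0.0 (so heaviside(0) = 1)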
                        let zero = self.real_type.const_float(0.0);
                        let one = self.real_type.const_float(1.0);
                        let result = self
                            .builder
                            .build_select(
                                self.builder
                                    .build_float_compare(FloatPredicate::OGE, x, zero, "x >= 0")
                                    .unwrap(),
                                one,
                                zero,
                                name,
                            )
                            .ok()?;
                        self.builder.build_return(Some(&result)).ok();
                        self.builder.position_at_end(current_block);
                        Some(fn_val)
                    }
                    "tanh" | "sinh" | "cosh" => {
                        let arg_len = 1;
                        let ret_type = self.real_type;

                        let args_types = std::iter::repeat_n(ret_type, arg_len)
                            .map(|f| f.into())
                            .collect::<Vec<BasicMetadataTypeEnum>>();

                        let fn_type = ret_type.fn_type(args_types.as_slice(), false);
                        let fn_val = self.module.add_function(name, fn_type, None);

                        for arg in fn_val.get_param_iter() {
                            arg.into_float_value().set_name("x");
                        }

                        let current_block = self.builder.get_insert_block().unwrap();
                        let basic_block = self.context.append_basic_block(fn_val, "entry");
                        self.builder.position_at_end(basic_block);
                        let x = fn_val.get_nth_param(0)?.into_float_value();
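                        // tanh(x) = (e^x - e^-x) / (e^x + e^-x),
                        // sinh(x) = (e^x - e^-x) / 2, cosh(x) = (e^x + e^-x) / 2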
                        let negx = self.builder.build_float_neg(x, name).ok()?;
                        let exp = self.get_function("exp").unwrap();
                        let exp_negx = self
                            .builder
                            .build_call(exp, &[BasicMetadataValueEnum::FloatValue(negx)], name)
                            .ok()?;
                        let expx = self
                            .builder
                            .build_call(exp, &[BasicMetadataValueEnum::FloatValue(x)], name)
                            .ok()?;
                        let expx_minus_exp_negx = self
                            .builder
                            .build_float_sub(
                                expx.try_as_basic_value().left().unwrap().into_float_value(),
                                exp_negx
                                    .try_as_basic_value()
                                    .left()
                                    .unwrap()
                                    .into_float_value(),
                                name,
                            )
                            .ok()?;
                        let expx_plus_exp_negx = self
                            .builder
                            .build_float_add(
                                expx.try_as_basic_value().left().unwrap().into_float_value(),
                                exp_negx
                                    .try_as_basic_value()
                                    .left()
                                    .unwrap()
                                    .into_float_value(),
                                name,
                            )
                            .ok()?;
                        let result = match name {
                            "tanh" => self
                                .builder
                                .build_float_div(expx_minus_exp_negx, expx_plus_exp_negx, name)
                                .ok()?,
                            "sinh" => self
                                .builder
                                .build_float_div(
                                    expx_minus_exp_negx,
                                    self.real_type.const_float(2.0),
                                    name,
                                )
                                .ok()?,
                            "cosh" => self
                                .builder
                                .build_float_div(
                                    expx_plus_exp_negx,
                                    self.real_type.const_float(2.0),
                                    name,
                                )
                                .ok()?,
                            _ => panic!("unknown function"),
                        };
                        self.builder.build_return(Some(&result)).ok();
                        self.builder.position_at_end(current_block);
                        Some(fn_val)
                    }
                    _ => None,
                }?;
                self.functions.insert(name.to_owned(), function);
                Some(function)
            }
        }
    }

    /// Returns the `FunctionValue` representing the function being compiled.
    #[inline]
    fn fn_value(&self) -> FunctionValue<'ctx> {
        self.fn_value_opt.unwrap()
    }

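    /// Returns the pointer to the storage of the tensor currently being compiled.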
    #[inline]
    fn tensor_ptr(&self) -> PointerValue<'ctx> {
        self.tensor_ptr_opt.unwrap()
    }

    /// Creates a new builder in the entry block of the function.
    fn create_entry_block_builder(&self) -> Builder<'ctx> {
        let builder = self.context.create_builder();
        let entry = self.fn_value().get_first_basic_block().unwrap();
        match entry.get_first_instruction() {
            Some(first_instr) => builder.position_before(&first_instr),
            None => builder.position_at_end(entry),
        }
        builder
    }

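    /// JIT compiles a rank-0 (scalar) tensor into `res_ptr_opt` (or a fresh
    /// entry-block alloca). When threaded, only thread 0 evaluates and stores
    /// the expression; all other threads branch straight to the exit block.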
    fn jit_compile_scalar(
        &mut self,
        a: &Tensor,
        res_ptr_opt: Option<PointerValue<'ctx>>,
    ) -> Result<PointerValue<'ctx>> {
        let res_type = self.real_type;
        let res_ptr = match res_ptr_opt {
            Some(ptr) => ptr,
            None => self
                .create_entry_block_builder()
                .build_alloca(res_type, a.name())?,
        };
        let name = a.name();
        let elmt = a.elmts().first().unwrap();

        // if threaded then only the first thread will evaluate the scalar
        let curr_block = self.builder.get_insert_block().unwrap();
        let mut next_block_opt = None;
        if self.threaded {
            let next_block = self.context.append_basic_block(self.fn_value(), "next");
            self.builder.position_at_end(next_block);
            next_block_opt = Some(next_block);
        }

        let float_value = self.jit_compile_expr(name, elmt.expr(), &[], elmt, None)?;
        self.builder.build_store(res_ptr, float_value)?;

        // complete the threading block
        if self.threaded {
            let exit_block = self.context.append_basic_block(self.fn_value(), "exit");
            self.builder.build_unconditional_branch(exit_block)?;
            self.builder.position_at_end(curr_block);

            let thread_id = self.get_param("thread_id");
            let thread_id = self
                .builder
                .build_load(self.int_type, *thread_id, "thread_id")
                .unwrap()
                .into_int_value();
            let is_first_thread = self.builder.build_int_compare(
                IntPredicate::EQ,
                thread_id,
                self.int_type.const_zero(),
                "is_first_thread",
            )?;
            self.builder.build_conditional_branch(
                is_first_thread,
                next_block_opt.unwrap(),
                exit_block,
            )?;

            self.builder.position_at_end(exit_block);
        }

        Ok(res_ptr)
    }

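    /// JIT compiles a tensor by compiling each of its blocks in turn into the
    /// storage given by `res_ptr_opt` (or a fresh entry-block alloca). Rank-0
    /// tensors are delegated to [`Self::jit_compile_scalar`].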
    fn jit_compile_tensor(
        &mut self,
        a: &Tensor,
        res_ptr_opt: Option<PointerValue<'ctx>>,
    ) -> Result<PointerValue<'ctx>> {
        // treat scalar as a special case
        if a.rank() == 0 {
            return self.jit_compile_scalar(a, res_ptr_opt);
        }

        let res_type = self.real_type;
        let res_ptr = match res_ptr_opt {
            Some(ptr) => ptr,
            None => self
                .create_entry_block_builder()
                .build_alloca(res_type, a.name())?,
        };

        // set up the tensor storage pointer and index into this data
        self.tensor_ptr_opt = Some(res_ptr);

        for (i, blk) in a.elmts().iter().enumerate() {
            let default = format!("{}-{}", a.name(), i);
            let name = blk.name().unwrap_or(default.as_str());
            self.jit_compile_block(name, a, blk)?;
        }
        Ok(res_ptr)
    }

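    /// Dispatches a single tensor block to the dense, diagonal, sparse or
    /// sparse-contraction code paths, based on the layout of its expression.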
    fn jit_compile_block(&mut self, name: &str, tensor: &Tensor, elmt: &TensorBlock) -> Result<()> {
        let translation = Translation::new(
            elmt.expr_layout(),
            elmt.layout(),
            elmt.start(),
            tensor.layout_ptr(),
        );

        if elmt.expr_layout().is_dense() {
            self.jit_compile_dense_block(name, elmt, &translation)
        } else if elmt.expr_layout().is_diagonal() {
            self.jit_compile_diagonal_block(name, elmt, &translation)
        } else if elmt.expr_layout().is_sparse() {
            match translation.source {
                TranslationFrom::SparseContraction { .. } => {
                    self.jit_compile_sparse_contraction_block(name, elmt, &translation)
                }
                _ => self.jit_compile_sparse_block(name, elmt, &translation),
            }
        } else {
            return Err(anyhow!(
                "unsupported block layout: {:?}",
                elmt.expr_layout()
            ));
        }
    }

    // for dense blocks we emit a nest of loops, one per dimension, to calculate the element index, then compile the expression passing in this index
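    // The generated IR is roughly (for a rank-2, non-contracting block):
    //
    //   for i0 in thread_start..thread_end {
    //       for i1 in 0..shape[1] {
    //           store(expr(i0, i1), i0 * stride0 + i1 * stride1)
    //       }
    //   }
    //
    // with phi nodes carrying the loop indices between basic blocks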
    fn jit_compile_dense_block(
        &mut self,
        name: &str,
        elmt: &TensorBlock,
        translation: &Translation,
    ) -> Result<()> {
        let int_type = self.int_type;

        let mut preblock = self.builder.get_insert_block().unwrap();
        let expr_rank = elmt.expr_layout().rank();
        let expr_shape = elmt
            .expr_layout()
            .shape()
            .mapv(|n| int_type.const_int(n.try_into().unwrap(), false));
        let one = int_type.const_int(1, false);

        let mut expr_strides = vec![1; expr_rank];
        if expr_rank > 0 {
            for i in (0..expr_rank - 1).rev() {
                expr_strides[i] = expr_strides[i + 1] * elmt.expr_layout().shape()[i + 1];
            }
        }
        let expr_strides = expr_strides
            .iter()
            .map(|&s| int_type.const_int(s.try_into().unwrap(), false))
            .collect::<Vec<IntValue>>();

        // setup indices, loop through the nested loops
        let mut indices = Vec::new();
        let mut blocks = Vec::new();

        // allocate the contract sum if needed
        let (contract_sum, contract_by, contract_strides) =
            if let TranslationFrom::DenseContraction {
                contract_by,
                contract_len: _,
            } = translation.source
            {
                let contract_rank = expr_rank - contract_by;
                let mut contract_strides = vec![1; contract_rank];
                for i in (0..contract_rank - 1).rev() {
                    contract_strides[i] =
                        contract_strides[i + 1] * elmt.expr_layout().shape()[i + 1];
                }
                let contract_strides = contract_strides
                    .iter()
                    .map(|&s| int_type.const_int(s.try_into().unwrap(), false))
                    .collect::<Vec<IntValue>>();
                (
                    Some(self.builder.build_alloca(self.real_type, "contract_sum")?),
                    contract_by,
                    Some(contract_strides),
                )
            } else {
                (None, 0, None)
            };

        // we will thread the output loop, except if we are contracting to a scalar
        let (thread_start, thread_end, test_done, next) = if self.threaded {
            let (start, end, test_done, next) =
                self.jit_threading_limits(*expr_shape.get(0).unwrap_or(&one))?;
            preblock = next;
            (Some(start), Some(end), Some(test_done), Some(next))
        } else {
            (None, None, None, None)
        };

        for i in 0..expr_rank {
            let block = self.context.append_basic_block(self.fn_value(), name);
            self.builder.build_unconditional_branch(block)?;
            self.builder.position_at_end(block);

            let start_index = if i == 0 && self.threaded {
                thread_start.unwrap()
            } else {
                self.int_type.const_zero()
            };

            let curr_index = self.builder.build_phi(int_type, format!("i{i}").as_str())?;
            curr_index.add_incoming(&[(&start_index, preblock)]);

            if i == expr_rank - contract_by - 1 && contract_sum.is_some() {
                self.builder
                    .build_store(contract_sum.unwrap(), self.real_type.const_zero())?;
            }

            indices.push(curr_index);
            blocks.push(block);
            preblock = block;
        }

        let indices_int: Vec<IntValue> = indices
            .iter()
            .map(|i| i.as_basic_value().into_int_value())
            .collect();

        let float_value =
            self.jit_compile_expr(name, elmt.expr(), indices_int.as_slice(), elmt, None)?;

        if contract_sum.is_some() {
            let contract_sum_value = self
                .build_load(self.real_type, contract_sum.unwrap(), "contract_sum")?
                .into_float_value();
            let new_contract_sum_value = self.builder.build_float_add(
                contract_sum_value,
                float_value,
                "new_contract_sum",
            )?;
            self.builder
                .build_store(contract_sum.unwrap(), new_contract_sum_value)?;
        } else {
            let expr_index = indices_int.iter().zip(expr_strides.iter()).fold(
                self.int_type.const_zero(),
                |acc, (i, s)| {
                    let tmp = self.builder.build_int_mul(*i, *s, "expr_index").unwrap();
                    self.builder.build_int_add(acc, tmp, "acc").unwrap()
                },
            );
            preblock = self.jit_compile_broadcast_and_store(
                name,
                elmt,
                float_value,
                expr_index,
                translation,
                preblock,
            )?;
        }

        // unwind the nested loops
        for i in (0..expr_rank).rev() {
            // increment index
            let next_index = self.builder.build_int_add(indices_int[i], one, name)?;
            indices[i].add_incoming(&[(&next_index, preblock)]);

            if i == expr_rank - contract_by - 1 && contract_sum.is_some() {
                let contract_sum_value = self
                    .build_load(self.real_type, contract_sum.unwrap(), "contract_sum")?
                    .into_float_value();
                let contract_strides = contract_strides.as_ref().unwrap();
                let elmt_index = indices_int
                    .iter()
                    .take(contract_strides.len())
                    .zip(contract_strides.iter())
                    .fold(self.int_type.const_zero(), |acc, (i, s)| {
                        let tmp = self.builder.build_int_mul(*i, *s, "elmt_index").unwrap();
                        self.builder.build_int_add(acc, tmp, "acc").unwrap()
                    });
                self.jit_compile_store(name, elmt, elmt_index, contract_sum_value, translation)?;
            }

            let end_index = if i == 0 && self.threaded {
                thread_end.unwrap()
            } else {
                expr_shape[i]
            };

            // loop condition
            let loop_while =
                self.builder
                    .build_int_compare(IntPredicate::ULT, next_index, end_index, name)?;
            let block = self.context.append_basic_block(self.fn_value(), name);
            self.builder
                .build_conditional_branch(loop_while, blocks[i], block)?;
            self.builder.position_at_end(block);
            preblock = block;
        }

        if self.threaded {
            self.jit_end_threading(
                thread_start.unwrap(),
                thread_end.unwrap(),
                test_done.unwrap(),
                next.unwrap(),
            )?;
        }
        Ok(())
    }

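    /// JIT compiles a sparse contraction block: for each non-zero output
    /// element, the start and end offsets of its contributing expression
    /// elements are loaded from the translation indices, the expression is
    /// summed over that range, and the sum is stored via
    /// [`Self::jit_compile_store`].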
    fn jit_compile_sparse_contraction_block(
        &mut self,
        name: &str,
        elmt: &TensorBlock,
        translation: &Translation,
    ) -> Result<()> {
        match translation.source {
            TranslationFrom::SparseContraction { .. } => {}
            _ => {
                panic!("expected sparse contraction")
            }
        }
        let int_type = self.int_type;

        let layout_index = self.layout.get_layout_index(elmt.expr_layout()).unwrap();
        let translation_index = self
            .layout
            .get_translation_index(elmt.expr_layout(), elmt.layout())
            .unwrap();
        let translation_index = translation_index + translation.get_from_index_in_data_layout();

        let final_contract_index =
            int_type.const_int(elmt.layout().nnz().try_into().unwrap(), false);
        let (thread_start, thread_end, test_done, next) = if self.threaded {
            let (start, end, test_done, next) = self.jit_threading_limits(final_contract_index)?;
            (Some(start), Some(end), Some(test_done), Some(next))
        } else {
            (None, None, None, None)
        };

        let preblock = self.builder.get_insert_block().unwrap();
        let contract_sum_ptr = self.builder.build_alloca(self.real_type, "contract_sum")?;

        // loop through each contraction
        let block = self.context.append_basic_block(self.fn_value(), name);
        self.builder.build_unconditional_branch(block)?;
        self.builder.position_at_end(block);

        let contract_index = self.builder.build_phi(int_type, "i")?;
        let contract_start = if self.threaded {
            thread_start.unwrap()
        } else {
            int_type.const_zero()
        };
        contract_index.add_incoming(&[(&contract_start, preblock)]);

        let start_index = self.builder.build_int_add(
            int_type.const_int(translation_index.try_into().unwrap(), false),
            self.builder.build_int_mul(
                int_type.const_int(2, false),
                contract_index.as_basic_value().into_int_value(),
                name,
            )?,
            name,
        )?;
        let end_index =
            self.builder
                .build_int_add(start_index, int_type.const_int(1, false), name)?;
        let start_ptr = self.build_gep(
            self.int_type,
            *self.get_param("indices"),
            &[start_index],
            "start_index_ptr",
        )?;
        let start_contract = self
            .build_load(self.int_type, start_ptr, "start")?
            .into_int_value();
        let end_ptr = self.build_gep(
            self.int_type,
            *self.get_param("indices"),
            &[end_index],
            "end_index_ptr",
        )?;
        let end_contract = self
            .build_load(self.int_type, end_ptr, "end")?
            .into_int_value();

        // initialise the contract sum
        self.builder
            .build_store(contract_sum_ptr, self.real_type.const_float(0.0))?;

        // loop through each element in the contraction
        let contract_block = self
            .context
            .append_basic_block(self.fn_value(), format!("{name}_contract").as_str());
        self.builder.build_unconditional_branch(contract_block)?;
        self.builder.position_at_end(contract_block);

        let expr_index_phi = self.builder.build_phi(int_type, "j")?;
        expr_index_phi.add_incoming(&[(&start_contract, block)]);

        // loop body - load index from layout
        let expr_index = expr_index_phi.as_basic_value().into_int_value();
        let elmt_index_mult_rank = self.builder.build_int_mul(
            expr_index,
            int_type.const_int(elmt.expr_layout().rank().try_into().unwrap(), false),
            name,
        )?;
        let indices_int = (0..elmt.expr_layout().rank())
            .map(|i| {
                let layout_index_plus_offset =
                    int_type.const_int((layout_index + i).try_into().unwrap(), false);
                let curr_index = self.builder.build_int_add(
                    elmt_index_mult_rank,
                    layout_index_plus_offset,
                    name,
                )?;
                let ptr = Self::get_ptr_to_index(
                    &self.builder,
                    self.int_type,
                    self.get_param("indices"),
                    curr_index,
                    name,
                );
                let index = self.build_load(self.int_type, ptr, name)?.into_int_value();
                Ok(index)
            })
            .collect::<Result<Vec<_>, anyhow::Error>>()?;

        // loop body - eval expression and increment sum
        let float_value = self.jit_compile_expr(
            name,
            elmt.expr(),
            indices_int.as_slice(),
            elmt,
            Some(expr_index),
        )?;
        let contract_sum_value = self
            .build_load(self.real_type, contract_sum_ptr, "contract_sum")?
            .into_float_value();
        let new_contract_sum_value =
            self.builder
                .build_float_add(contract_sum_value, float_value, "new_contract_sum")?;
        self.builder
            .build_store(contract_sum_ptr, new_contract_sum_value)?;

        // increment contract loop index
        let next_elmt_index =
            self.builder
                .build_int_add(expr_index, int_type.const_int(1, false), name)?;
        expr_index_phi.add_incoming(&[(&next_elmt_index, contract_block)]);

        // contract loop condition
        let loop_while = self.builder.build_int_compare(
            IntPredicate::ULT,
            next_elmt_index,
            end_contract,
            name,
        )?;
        let post_contract_block = self.context.append_basic_block(self.fn_value(), name);
        self.builder
            .build_conditional_branch(loop_while, contract_block, post_contract_block)?;
        self.builder.position_at_end(post_contract_block);

        // store the result
        self.jit_compile_store(
            name,
            elmt,
            contract_index.as_basic_value().into_int_value(),
            new_contract_sum_value,
            translation,
        )?;

        // increment outer loop index
        let next_contract_index = self.builder.build_int_add(
            contract_index.as_basic_value().into_int_value(),
            int_type.const_int(1, false),
            name,
        )?;
        contract_index.add_incoming(&[(&next_contract_index, post_contract_block)]);

        // outer loop condition
        let loop_while = self.builder.build_int_compare(
            IntPredicate::ULT,
            next_contract_index,
            thread_end.unwrap_or(final_contract_index),
            name,
        )?;
        let post_block = self.context.append_basic_block(self.fn_value(), name);
        self.builder
            .build_conditional_branch(loop_while, block, post_block)?;
        self.builder.position_at_end(post_block);

        if self.threaded {
            self.jit_end_threading(
                thread_start.unwrap(),
                thread_end.unwrap(),
                test_done.unwrap(),
                next.unwrap(),
            )?;
        }

        Ok(())
    }

    // for sparse blocks we loop through the non-zero elements, extract each element's index from the layout, then compile the expression passing in this index
    // TODO: haven't implemented contractions yet
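    // The generated IR is roughly:
    //
    //   for i in thread_start..thread_end {   // over the non-zero elements
    //       let indices = layout_indices[i * rank..(i + 1) * rank];
    //       store(expr(indices), i)
    //   }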
    fn jit_compile_sparse_block(
        &mut self,
        name: &str,
        elmt: &TensorBlock,
        translation: &Translation,
    ) -> Result<()> {
        let int_type = self.int_type;

        let layout_index = self.layout.get_layout_index(elmt.expr_layout()).unwrap();
        let start_index = int_type.const_int(0, false);
        let end_index = int_type.const_int(elmt.expr_layout().nnz().try_into().unwrap(), false);

        let (thread_start, thread_end, test_done, next) = if self.threaded {
            let (start, end, test_done, next) = self.jit_threading_limits(end_index)?;
            (Some(start), Some(end), Some(test_done), Some(next))
        } else {
            (None, None, None, None)
        };

        // loop through the non-zero elements
        let preblock = self.builder.get_insert_block().unwrap();
        let mut block = self.context.append_basic_block(self.fn_value(), name);
        self.builder.build_unconditional_branch(block)?;
        self.builder.position_at_end(block);

        let curr_index = self.builder.build_phi(int_type, "i")?;
        curr_index.add_incoming(&[(&thread_start.unwrap_or(start_index), preblock)]);

        // loop body - load index from layout
        let elmt_index = curr_index.as_basic_value().into_int_value();
        let elmt_index_mult_rank = self.builder.build_int_mul(
            elmt_index,
            int_type.const_int(elmt.expr_layout().rank().try_into().unwrap(), false),
            name,
        )?;
        let indices_int = (0..elmt.expr_layout().rank())
            .map(|i| {
                let layout_index_plus_offset =
                    int_type.const_int((layout_index + i).try_into().unwrap(), false);
                let curr_index = self.builder.build_int_add(
                    elmt_index_mult_rank,
                    layout_index_plus_offset,
                    name,
                )?;
                let ptr = Self::get_ptr_to_index(
                    &self.builder,
                    self.int_type,
                    self.get_param("indices"),
                    curr_index,
                    name,
                );
                Ok(self.build_load(self.int_type, ptr, name)?.into_int_value())
            })
            .collect::<Result<Vec<_>, anyhow::Error>>()?;

        // loop body - eval expression
        let float_value = self.jit_compile_expr(
            name,
            elmt.expr(),
            indices_int.as_slice(),
            elmt,
            Some(elmt_index),
        )?;

        block = self.jit_compile_broadcast_and_store(
            name,
            elmt,
            float_value,
            elmt_index,
            translation,
            block,
        )?;

        // increment loop index
        let one = int_type.const_int(1, false);
        let next_index = self.builder.build_int_add(elmt_index, one, name)?;
        curr_index.add_incoming(&[(&next_index, block)]);

        // loop condition
        let loop_while = self.builder.build_int_compare(
            IntPredicate::ULT,
            next_index,
            thread_end.unwrap_or(end_index),
            name,
        )?;
        let post_block = self.context.append_basic_block(self.fn_value(), name);
        self.builder
            .build_conditional_branch(loop_while, block, post_block)?;
        self.builder.position_at_end(post_block);

        if self.threaded {
            self.jit_end_threading(
                thread_start.unwrap(),
                thread_end.unwrap(),
                test_done.unwrap(),
                next.unwrap(),
            )?;
        }

        Ok(())
    }

    // for diagonal blocks we loop through the diagonal elements; the index is the same in every dimension, and we compile the expression passing in this repeated index
    fn jit_compile_diagonal_block(
        &mut self,
        name: &str,
        elmt: &TensorBlock,
        translation: &Translation,
    ) -> Result<()> {
        let int_type = self.int_type;

        let start_index = int_type.const_int(0, false);
        let end_index = int_type.const_int(elmt.expr_layout().nnz().try_into().unwrap(), false);

        let (thread_start, thread_end, test_done, next) = if self.threaded {
            let (start, end, test_done, next) = self.jit_threading_limits(end_index)?;
            (Some(start), Some(end), Some(test_done), Some(next))
        } else {
            (None, None, None, None)
        };

        // loop through the non-zero elements
        let preblock = self.builder.get_insert_block().unwrap();
        let mut block = self.context.append_basic_block(self.fn_value(), name);
        self.builder.build_unconditional_branch(block)?;
        self.builder.position_at_end(block);

        let curr_index = self.builder.build_phi(int_type, "i")?;
        curr_index.add_incoming(&[(&thread_start.unwrap_or(start_index), preblock)]);

        // loop body - index is just the same for each element
        let elmt_index = curr_index.as_basic_value().into_int_value();
        let indices_int: Vec<IntValue> =
            (0..elmt.expr_layout().rank()).map(|_| elmt_index).collect();

        // loop body - eval expression
        let float_value = self.jit_compile_expr(
            name,
            elmt.expr(),
            indices_int.as_slice(),
            elmt,
            Some(elmt_index),
        )?;

        // loop body - store result
        block = self.jit_compile_broadcast_and_store(
            name,
            elmt,
            float_value,
            elmt_index,
            translation,
            block,
        )?;

        // increment loop index
        let one = int_type.const_int(1, false);
        let next_index = self.builder.build_int_add(elmt_index, one, name)?;
        curr_index.add_incoming(&[(&next_index, block)]);

        // loop condition
        let loop_while = self.builder.build_int_compare(
            IntPredicate::ULT,
            next_index,
            thread_end.unwrap_or(end_index),
            name,
        )?;
        let post_block = self.context.append_basic_block(self.fn_value(), name);
        self.builder
            .build_conditional_branch(loop_while, block, post_block)?;
        self.builder.position_at_end(post_block);

        if self.threaded {
            self.jit_end_threading(
                thread_start.unwrap(),
                thread_end.unwrap(),
                test_done.unwrap(),
                next.unwrap(),
            )?;
        }

        Ok(())
    }

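    /// Stores `float_value` according to the translation source. For broadcasts
    /// an inner loop writes the value at `expr_index * broadcast_len + j` for
    /// each `j` in `0..broadcast_len`; element-wise and diagonal-contraction
    /// sources store directly at `expr_index`. Returns the basic block that
    /// control flow ends in, so callers can keep their loop phi nodes correct.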
    fn jit_compile_broadcast_and_store(
        &mut self,
        name: &str,
        elmt: &TensorBlock,
        float_value: FloatValue<'ctx>,
        expr_index: IntValue<'ctx>,
        translation: &Translation,
        pre_block: BasicBlock<'ctx>,
    ) -> Result<BasicBlock<'ctx>> {
        let int_type = self.int_type;
        let one = int_type.const_int(1, false);
        let zero = int_type.const_int(0, false);
        match translation.source {
            TranslationFrom::Broadcast {
                broadcast_by: _,
                broadcast_len,
            } => {
                let bcast_start_index = zero;
                let bcast_end_index = int_type.const_int(broadcast_len.try_into().unwrap(), false);

                // setup loop block
                let bcast_block = self.context.append_basic_block(self.fn_value(), name);
                self.builder.build_unconditional_branch(bcast_block)?;
                self.builder.position_at_end(bcast_block);
                let bcast_index = self.builder.build_phi(int_type, "broadcast_index")?;
                bcast_index.add_incoming(&[(&bcast_start_index, pre_block)]);

                // store value
                let store_index = self.builder.build_int_add(
                    self.builder
                        .build_int_mul(expr_index, bcast_end_index, "store_index")?,
                    bcast_index.as_basic_value().into_int_value(),
                    "bcast_store_index",
                )?;
                self.jit_compile_store(name, elmt, store_index, float_value, translation)?;

                // increment index
                let bcast_next_index = self.builder.build_int_add(
                    bcast_index.as_basic_value().into_int_value(),
                    one,
                    name,
                )?;
                bcast_index.add_incoming(&[(&bcast_next_index, bcast_block)]);

                // loop condition
                let bcast_cond = self.builder.build_int_compare(
                    IntPredicate::ULT,
                    bcast_next_index,
                    bcast_end_index,
                    "broadcast_cond",
                )?;
                let post_bcast_block = self.context.append_basic_block(self.fn_value(), name);
                self.builder
                    .build_conditional_branch(bcast_cond, bcast_block, post_bcast_block)?;
                self.builder.position_at_end(post_bcast_block);

                // return the current block for later
                Ok(post_bcast_block)
            }
            TranslationFrom::ElementWise | TranslationFrom::DiagonalContraction { .. } => {
                self.jit_compile_store(name, elmt, expr_index, float_value, translation)?;
                Ok(pre_block)
            }
            _ => Err(anyhow!("Invalid translation")),
        }
    }

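    /// Stores `float_value` into the result tensor. For contiguous targets the
    /// data index is `start + store_index`; for sparse targets it is loaded
    /// indirectly from the translation indices in the data layout.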
    fn jit_compile_store(
        &mut self,
        name: &str,
        elmt: &TensorBlock,
        store_index: IntValue<'ctx>,
        float_value: FloatValue<'ctx>,
        translation: &Translation,
    ) -> Result<()> {
        let int_type = self.int_type;
        let res_index = match &translation.target {
            TranslationTo::Contiguous { start, end: _ } => {
                let start_const = int_type.const_int((*start).try_into().unwrap(), false);
                self.builder.build_int_add(start_const, store_index, name)?
            }
            TranslationTo::Sparse { indices: _ } => {
                // load store index from layout
                let translate_index = self
                    .layout
                    .get_translation_index(elmt.expr_layout(), elmt.layout())
                    .unwrap();
                let translate_store_index =
                    translate_index + translation.get_to_index_in_data_layout();
                let translate_store_index =
                    int_type.const_int(translate_store_index.try_into().unwrap(), false);
                let elmt_index_strided = store_index;
                let curr_index =
                    self.builder
                        .build_int_add(elmt_index_strided, translate_store_index, name)?;
                let ptr = Self::get_ptr_to_index(
                    &self.builder,
                    self.int_type,
                    self.get_param("indices"),
                    curr_index,
                    name,
                );
                self.build_load(self.int_type, ptr, name)?.into_int_value()
            }
        };

        let resi_ptr = Self::get_ptr_to_index(
            &self.builder,
            self.real_type,
            &self.tensor_ptr(),
            res_index,
            name,
        );
        self.builder.build_store(resi_ptr, float_value)?;
        Ok(())
    }

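    /// Recursively lowers an expression AST to LLVM float instructions. Tensor
    /// name operands are loaded by permuting the loop indices to match the
    /// operand's index characters and folding them with the operand's strides
    /// into an element index; sparse and diagonal operands reuse `expr_index`.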
    fn jit_compile_expr(
        &mut self,
        name: &str,
        expr: &Ast,
        index: &[IntValue<'ctx>],
        elmt: &TensorBlock,
        expr_index: Option<IntValue<'ctx>>,
    ) -> Result<FloatValue<'ctx>> {
        let name = elmt.name().unwrap_or(name);
        match &expr.kind {
            AstKind::Binop(binop) => {
                let lhs =
                    self.jit_compile_expr(name, binop.left.as_ref(), index, elmt, expr_index)?;
                let rhs =
                    self.jit_compile_expr(name, binop.right.as_ref(), index, elmt, expr_index)?;
                match binop.op {
                    '*' => Ok(self.builder.build_float_mul(lhs, rhs, name)?),
                    '/' => Ok(self.builder.build_float_div(lhs, rhs, name)?),
                    '-' => Ok(self.builder.build_float_sub(lhs, rhs, name)?),
                    '+' => Ok(self.builder.build_float_add(lhs, rhs, name)?),
                    unknown => Err(anyhow!("unknown binop op '{}'", unknown)),
                }
            }
            AstKind::Monop(monop) => {
                let child =
                    self.jit_compile_expr(name, monop.child.as_ref(), index, elmt, expr_index)?;
                match monop.op {
                    '-' => Ok(self.builder.build_float_neg(child, name)?),
                    unknown => Err(anyhow!("unknown monop op '{}'", unknown)),
                }
            }
            AstKind::Call(call) => match self.get_function(call.fn_name) {
                Some(function) => {
                    let mut args: Vec<BasicMetadataValueEnum> = Vec::new();
                    for arg in call.args.iter() {
                        let arg_val =
                            self.jit_compile_expr(name, arg.as_ref(), index, elmt, expr_index)?;
                        args.push(BasicMetadataValueEnum::FloatValue(arg_val));
                    }
                    let ret_value = self
                        .builder
                        .build_call(function, args.as_slice(), name)?
                        .try_as_basic_value()
                        .left()
                        .unwrap()
                        .into_float_value();
                    Ok(ret_value)
                }
                None => Err(anyhow!("unknown function call '{}'", call.fn_name)),
            },
            AstKind::CallArg(arg) => {
                self.jit_compile_expr(name, &arg.expression, index, elmt, expr_index)
            }
            AstKind::Number(value) => Ok(self.real_type.const_float(*value)),
            AstKind::Name(iname) => {
                let ptr = self.get_param(iname.name);
                let layout = self.layout.get_layout(iname.name).unwrap();
                let iname_elmt_index = if layout.is_dense() {
                    // permute indices based on the index chars of this tensor
                    let mut no_transform = true;
                    let mut iname_index = Vec::new();
                    for (i, c) in iname.indices.iter().enumerate() {
                        // find the position index of this index char in the tensor's index chars,
                        // if it's not found then it must be a contraction index so is at the end
                        let pi = elmt
                            .indices()
                            .iter()
                            .position(|x| x == c)
                            .unwrap_or(elmt.indices().len());
                        iname_index.push(index[pi]);
                        no_transform = no_transform && pi == i;
                    }
                    // calculate the element index using iname_index and the shape of the tensor
                    // TODO: can we optimise this by using expr_index, and also including elmt_index?
                    if !iname_index.is_empty() {
                        let mut iname_elmt_index = *iname_index.last().unwrap();
                        let mut stride = 1u64;
                        for i in (0..iname_index.len() - 1).rev() {
                            let iname_i = iname_index[i];
                            let shapei: u64 = layout.shape()[i + 1].try_into().unwrap();
                            stride *= shapei;
                            let stride_intval = self.context.i32_type().const_int(stride, false);
                            let stride_mul_i =
                                self.builder.build_int_mul(stride_intval, iname_i, name)?;
                            iname_elmt_index =
                                self.builder
                                    .build_int_add(iname_elmt_index, stride_mul_i, name)?;
                        }
                        Some(iname_elmt_index)
                    } else {
                        let zero = self.context.i32_type().const_int(0, false);
                        Some(zero)
                    }
                } else if layout.is_sparse() || layout.is_diagonal() {
                    // must have come from jit_compile_sparse_block or
                    // jit_compile_diagonal_block, so we can just use the elmt_index
                    expr_index
                } else {
                    panic!("unexpected layout");
                };
                let value_ptr = match iname_elmt_index {
                    Some(index) => {
                        Self::get_ptr_to_index(&self.builder, self.real_type, ptr, index, name)
                    }
                    None => *ptr,
                };
                Ok(self
                    .build_load(self.real_type, value_ptr, name)?
                    .into_float_value())
            }
            AstKind::NamedGradient(name) => {
                let name_str = name.to_string();
                let ptr = self.get_param(name_str.as_str());
                Ok(self
                    .build_load(self.real_type, *ptr, name_str.as_str())?
                    .into_float_value())
            }
            AstKind::Index(_) => todo!(),
            AstKind::Slice(_) => todo!(),
            AstKind::Integer(_) => todo!(),
            _ => panic!("unexpected AstKind"),
        }
    }

    fn clear(&mut self) {
        self.variables.clear();
        //self.functions.clear();
        self.fn_value_opt = None;
        self.tensor_ptr_opt = None;
    }

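    /// Ensures a function argument is addressable: pointer arguments are used
    /// as-is, while float and int arguments are copied into an entry-block
    /// alloca so that every parameter can be accessed through a pointer.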
    fn function_arg_alloca(&mut self, name: &str, arg: BasicValueEnum<'ctx>) -> PointerValue<'ctx> {
        match arg {
            BasicValueEnum::PointerValue(v) => v,
            BasicValueEnum::FloatValue(v) => {
                let alloca = self
                    .create_entry_block_builder()
                    .build_alloca(arg.get_type(), name)
                    .unwrap();
                self.builder.build_store(alloca, v).unwrap();
                alloca
            }
            BasicValueEnum::IntValue(v) => {
                let alloca = self
                    .create_entry_block_builder()
                    .build_alloca(arg.get_type(), name)
                    .unwrap();
                self.builder.build_store(alloca, v).unwrap();
                alloca
            }
            _ => unreachable!(),
        }
    }

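    /// Compiles `set_u0(u0, data, thread_id, thread_dim)`, which evaluates the
    /// input-dependent definitions and then the initial state `u0`, emitting a
    /// barrier call after each tensor so threaded evaluation stays in step.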
    pub fn compile_set_u0<'m>(&mut self, model: &'m DiscreteModel) -> Result<FunctionValue<'ctx>> {
        self.clear();
        let void_type = self.context.void_type();
        let fn_type = void_type.fn_type(
            &[
                self.real_ptr_type.into(),
                self.real_ptr_type.into(),
                self.int_type.into(),
                self.int_type.into(),
            ],
            false,
        );
        let fn_arg_names = &["u0", "data", "thread_id", "thread_dim"];
        let function = self.module.add_function("set_u0", fn_type, None);

        // add noalias
        let alias_id = Attribute::get_named_enum_kind_id("noalias");
        let noalias = self.context.create_enum_attribute(alias_id, 0);
        for i in &[0, 1] {
            function.add_attribute(AttributeLoc::Param(*i), noalias);
        }

        let basic_block = self.context.append_basic_block(function, "entry");
        self.fn_value_opt = Some(function);
        self.builder.position_at_end(basic_block);

        for (i, arg) in function.get_param_iter().enumerate() {
            let name = fn_arg_names[i];
            let alloca = self.function_arg_alloca(name, arg);
            self.insert_param(name, alloca);
        }

        self.insert_data(model);
        self.insert_indices();

        let mut nbarriers = 0;
        let total_barriers = (model.input_dep_defns().len() + 1) as u64;
        #[allow(clippy::explicit_counter_loop)]
        for a in model.input_dep_defns() {
            self.jit_compile_tensor(a, Some(*self.get_var(a)))?;
            self.jit_compile_call_barrier(nbarriers, total_barriers);
            nbarriers += 1;
        }

        self.jit_compile_tensor(model.state(), Some(*self.get_param("u0")))?;
        self.jit_compile_call_barrier(nbarriers, total_barriers);

        self.builder.build_return(None)?;

        if function.verify(true) {
            Ok(function)
        } else {
            function.print_to_stderr();
            unsafe {
                function.delete();
            }
            Err(anyhow!("Invalid generated function."))
        }
    }

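    /// Compiles `calc_out(t, u, data, out, thread_id, thread_dim)` (or
    /// `calc_out_full` when `include_constants` is true, which additionally
    /// re-evaluates the input-dependent definitions). Time-dependent and
    /// state-dependent definitions are computed before the output tensor,
    /// with a barrier call after each tensor.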
    pub fn compile_calc_out<'m>(
        &mut self,
        model: &'m DiscreteModel,
        include_constants: bool,
    ) -> Result<FunctionValue<'ctx>> {
        self.clear();
        let void_type = self.context.void_type();
        let fn_type = void_type.fn_type(
            &[
                self.real_type.into(),
                self.real_ptr_type.into(),
                self.real_ptr_type.into(),
                self.real_ptr_type.into(),
                self.int_type.into(),
                self.int_type.into(),
            ],
            false,
        );
        let fn_arg_names = &["t", "u", "data", "out", "thread_id", "thread_dim"];
        let function_name = if include_constants {
            "calc_out_full"
        } else {
            "calc_out"
        };
        let function = self.module.add_function(function_name, fn_type, None);

        // add noalias
        let alias_id = Attribute::get_named_enum_kind_id("noalias");
        let noalias = self.context.create_enum_attribute(alias_id, 0);
        for i in &[1, 2] {
            function.add_attribute(AttributeLoc::Param(*i), noalias);
        }

        let basic_block = self.context.append_basic_block(function, "entry");
        self.fn_value_opt = Some(function);
        self.builder.position_at_end(basic_block);

        for (i, arg) in function.get_param_iter().enumerate() {
            let name = fn_arg_names[i];
            let alloca = self.function_arg_alloca(name, arg);
            self.insert_param(name, alloca);
        }

        self.insert_state(model.state());
        self.insert_data(model);
        self.insert_indices();

        // print thread_id and thread_dim
        //let thread_id = function.get_nth_param(3).unwrap();
        //let thread_dim = function.get_nth_param(4).unwrap();
        //self.compile_print_value("thread_id", PrintValue::Int(thread_id.into_int_value()))?;
        //self.compile_print_value("thread_dim", PrintValue::Int(thread_dim.into_int_value()))?;
        if let Some(out) = model.out() {
            let mut nbarriers = 0;
            let mut total_barriers =
                (model.time_dep_defns().len() + model.state_dep_defns().len() + 1) as u64;
            if include_constants {
                total_barriers += model.input_dep_defns().len() as u64;
                // calculate time-independent definitions
                for tensor in model.input_dep_defns() {
                    self.jit_compile_tensor(tensor, Some(*self.get_var(tensor)))?;
                    self.jit_compile_call_barrier(nbarriers, total_barriers);
                    nbarriers += 1;
                }
            }

            // calculate time-dependent definitions
            for tensor in model.time_dep_defns() {
                self.jit_compile_tensor(tensor, Some(*self.get_var(tensor)))?;
                self.jit_compile_call_barrier(nbarriers, total_barriers);
                nbarriers += 1;
            }

            // calculate state-dependent definitions
2636            #[allow(clippy::explicit_counter_loop)]
2637            for a in model.state_dep_defns() {
2638                self.jit_compile_tensor(a, Some(*self.get_var(a)))?;
2639                self.jit_compile_call_barrier(nbarriers, total_barriers);
2640                nbarriers += 1;
2641            }
2642
2643            self.jit_compile_tensor(out, Some(*self.get_var(model.out().unwrap())))?;
2644            self.jit_compile_call_barrier(nbarriers, total_barriers);
2645        }
2646        self.builder.build_return(None)?;
2647
2648        if function.verify(true) {
2649            Ok(function)
2650        } else {
2651            function.print_to_stderr();
2652            unsafe {
2653                function.delete();
2654            }
2655            Err(anyhow!("Invalid generated function."))
2656        }
2657    }
2658
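    /// Compile the `calc_stop` function, which evaluates the model's stop
    /// tensor (used for root finding) into the `root` output buffer. The
    /// generated signature is `calc_stop(t, u, data, root, thread_id,
    /// thread_dim)`; if the model has no stop tensor the body is empty.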
    pub fn compile_calc_stop<'m>(
        &mut self,
        model: &'m DiscreteModel,
    ) -> Result<FunctionValue<'ctx>> {
        self.clear();
        let void_type = self.context.void_type();
        let fn_type = void_type.fn_type(
            &[
                self.real_type.into(),
                self.real_ptr_type.into(),
                self.real_ptr_type.into(),
                self.real_ptr_type.into(),
                self.int_type.into(),
                self.int_type.into(),
            ],
            false,
        );
        let fn_arg_names = &["t", "u", "data", "root", "thread_id", "thread_dim"];
        let function = self.module.add_function("calc_stop", fn_type, None);

        // add the noalias attribute to the pointer arguments
        let alias_id = Attribute::get_named_enum_kind_id("noalias");
        let noalias = self.context.create_enum_attribute(alias_id, 0);
        for i in &[1, 2, 3] {
            function.add_attribute(AttributeLoc::Param(*i), noalias);
        }

        let basic_block = self.context.append_basic_block(function, "entry");
        self.fn_value_opt = Some(function);
        self.builder.position_at_end(basic_block);

        for (i, arg) in function.get_param_iter().enumerate() {
            let name = fn_arg_names[i];
            let alloca = self.function_arg_alloca(name, arg);
            self.insert_param(name, alloca);
        }

        self.insert_state(model.state());
        self.insert_data(model);
        self.insert_indices();

        if let Some(stop) = model.stop() {
            let mut nbarriers = 0;
            let total_barriers =
                (model.time_dep_defns().len() + model.state_dep_defns().len() + 1) as u64;
            // calculate time-dependent definitions
            for tensor in model.time_dep_defns() {
                self.jit_compile_tensor(tensor, Some(*self.get_var(tensor)))?;
                self.jit_compile_call_barrier(nbarriers, total_barriers);
                nbarriers += 1;
            }

            // calculate state-dependent definitions
            for a in model.state_dep_defns() {
                self.jit_compile_tensor(a, Some(*self.get_var(a)))?;
                self.jit_compile_call_barrier(nbarriers, total_barriers);
                nbarriers += 1;
            }

            let res_ptr = self.get_param("root");
            self.jit_compile_tensor(stop, Some(*res_ptr))?;
            self.jit_compile_call_barrier(nbarriers, total_barriers);
        }
        self.builder.build_return(None)?;

        if function.verify(true) {
            Ok(function)
        } else {
            function.print_to_stderr();
            unsafe {
                function.delete();
            }
            Err(anyhow!("Invalid generated function."))
        }
    }

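    /// Compile the right-hand side function `rhs` (or `rhs_full` when
    /// `include_constants` is true, which additionally recomputes the
    /// input-dependent constant definitions). The result F is written to the
    /// `rr` buffer; the generated signature is
    /// `rhs(t, u, data, rr, thread_id, thread_dim)`.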
    pub fn compile_rhs<'m>(
        &mut self,
        model: &'m DiscreteModel,
        include_constants: bool,
    ) -> Result<FunctionValue<'ctx>> {
        self.clear();
        let void_type = self.context.void_type();
        let fn_type = void_type.fn_type(
            &[
                self.real_type.into(),
                self.real_ptr_type.into(),
                self.real_ptr_type.into(),
                self.real_ptr_type.into(),
                self.int_type.into(),
                self.int_type.into(),
            ],
            false,
        );
        let fn_arg_names = &["t", "u", "data", "rr", "thread_id", "thread_dim"];
        let function_name = if include_constants { "rhs_full" } else { "rhs" };
        let function = self.module.add_function(function_name, fn_type, None);

        // add the noalias attribute to the pointer arguments
        let alias_id = Attribute::get_named_enum_kind_id("noalias");
        let noalias = self.context.create_enum_attribute(alias_id, 0);
        for i in &[1, 2, 3] {
            function.add_attribute(AttributeLoc::Param(*i), noalias);
        }

        let basic_block = self.context.append_basic_block(function, "entry");
        self.fn_value_opt = Some(function);
        self.builder.position_at_end(basic_block);

        for (i, arg) in function.get_param_iter().enumerate() {
            let name = fn_arg_names[i];
            let alloca = self.function_arg_alloca(name, arg);
            self.insert_param(name, alloca);
        }

        self.insert_state(model.state());
        self.insert_data(model);
        self.insert_indices();

        let mut nbarriers = 0;
        let mut total_barriers =
            (model.time_dep_defns().len() + model.state_dep_defns().len() + 1) as u64;
        if include_constants {
            total_barriers += model.input_dep_defns().len() as u64;
            // calculate constant definitions
            for tensor in model.input_dep_defns() {
                self.jit_compile_tensor(tensor, Some(*self.get_var(tensor)))?;
                self.jit_compile_call_barrier(nbarriers, total_barriers);
                nbarriers += 1;
            }
        }

        // calculate time-dependent definitions
        for tensor in model.time_dep_defns() {
            self.jit_compile_tensor(tensor, Some(*self.get_var(tensor)))?;
            self.jit_compile_call_barrier(nbarriers, total_barriers);
            nbarriers += 1;
        }

        // calculate state-dependent definitions
        // TODO: could split state dep defns into before and after F
        for a in model.state_dep_defns() {
            self.jit_compile_tensor(a, Some(*self.get_var(a)))?;
            self.jit_compile_call_barrier(nbarriers, total_barriers);
            nbarriers += 1;
        }

        // F: compile the right-hand side itself into rr
        let res_ptr = self.get_param("rr");
        self.jit_compile_tensor(model.rhs(), Some(*res_ptr))?;
        self.jit_compile_call_barrier(nbarriers, total_barriers);

        self.builder.build_return(None)?;

        if function.verify(true) {
            Ok(function)
        } else {
            function.print_to_stderr();
            unsafe {
                function.delete();
            }
            Err(anyhow!("Invalid generated function."))
        }
    }

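    /// Compile the `mass` function, which applies the mass matrix (the model
    /// lhs) to `dudt`, writing the result to `rr`. The generated signature is
    /// `mass(t, dudt, data, rr, thread_id, thread_dim)`; the body is empty
    /// unless the model has both a state derivative and an lhs.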
    pub fn compile_mass<'m>(&mut self, model: &'m DiscreteModel) -> Result<FunctionValue<'ctx>> {
        self.clear();
        let void_type = self.context.void_type();
        let fn_type = void_type.fn_type(
            &[
                self.real_type.into(),
                self.real_ptr_type.into(),
                self.real_ptr_type.into(),
                self.real_ptr_type.into(),
                self.int_type.into(),
                self.int_type.into(),
            ],
            false,
        );
        let fn_arg_names = &["t", "dudt", "data", "rr", "thread_id", "thread_dim"];
        let function = self.module.add_function("mass", fn_type, None);

        // add the noalias attribute to the pointer arguments
        let alias_id = Attribute::get_named_enum_kind_id("noalias");
        let noalias = self.context.create_enum_attribute(alias_id, 0);
        for i in &[1, 2, 3] {
            function.add_attribute(AttributeLoc::Param(*i), noalias);
        }

        let basic_block = self.context.append_basic_block(function, "entry");
        self.fn_value_opt = Some(function);
        self.builder.position_at_end(basic_block);

        for (i, arg) in function.get_param_iter().enumerate() {
            let name = fn_arg_names[i];
            let alloca = self.function_arg_alloca(name, arg);
            self.insert_param(name, alloca);
        }

        // only generate a body if the model has both a state derivative and an lhs
        if let (Some(state_dot), Some(lhs)) = (model.state_dot(), model.lhs()) {
            self.insert_dot_state(state_dot);
            self.insert_data(model);
            self.insert_indices();

            let mut nbarriers = 0;
            let total_barriers =
                (model.time_dep_defns().len() + model.dstate_dep_defns().len() + 1) as u64;
            // calculate time-dependent definitions
            for tensor in model.time_dep_defns() {
                self.jit_compile_tensor(tensor, Some(*self.get_var(tensor)))?;
                self.jit_compile_call_barrier(nbarriers, total_barriers);
                nbarriers += 1;
            }

            // calculate definitions that depend on dudt
            for a in model.dstate_dep_defns() {
                self.jit_compile_tensor(a, Some(*self.get_var(a)))?;
                self.jit_compile_call_barrier(nbarriers, total_barriers);
                nbarriers += 1;
            }

            // mass: compile the lhs into rr
            let res_ptr = self.get_param("rr");
            self.jit_compile_tensor(lhs, Some(*res_ptr))?;
            self.jit_compile_call_barrier(nbarriers, total_barriers);
        }

        self.builder.build_return(None)?;

        if function.verify(true) {
            Ok(function)
        } else {
            function.print_to_stderr();
            unsafe {
                function.delete();
            }
            Err(anyhow!("Invalid generated function."))
        }
    }

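    /// Differentiate `original_function` with Enzyme and wrap the result in a
    /// new function named `fn_name`. For every argument marked `Dup` or
    /// `DupNoNeed` in `args_type` the wrapper takes an extra shadow argument
    /// immediately after the primal one; `Const` arguments are passed through
    /// unchanged. `mode` selects forward- or reverse-mode differentiation.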
    pub fn compile_gradient(
        &mut self,
        original_function: FunctionValue<'ctx>,
        args_type: &[CompileGradientArgType],
        mode: CompileMode,
        fn_name: &str,
    ) -> Result<FunctionValue<'ctx>> {
        self.clear();

        // construct the gradient function
        let mut fn_type: Vec<BasicMetadataTypeEnum> = Vec::new();

        let orig_fn_type_ptr = Self::fn_pointer_type(self.context, original_function.get_type());

        let mut enzyme_fn_type: Vec<BasicMetadataTypeEnum> = vec![orig_fn_type_ptr.into()];
        let mut start_param_index: Vec<u32> = Vec::new();
        let mut ptr_arg_indices: Vec<u32> = Vec::new();
        for (i, arg) in original_function.get_param_iter().enumerate() {
            start_param_index.push(u32::try_from(fn_type.len()).unwrap());
            let arg_type = arg.get_type();
            fn_type.push(arg_type.into());

            // constant args with type T in the original function have 2 args of type [int, T]
            enzyme_fn_type.push(self.int_type.into());
            enzyme_fn_type.push(arg.get_type().into());

            if arg_type.is_pointer_type() {
                ptr_arg_indices.push(u32::try_from(i).unwrap());
            }

            match args_type[i] {
                CompileGradientArgType::Dup | CompileGradientArgType::DupNoNeed => {
                    fn_type.push(arg.get_type().into());
                    enzyme_fn_type.push(arg.get_type().into());
                }
                CompileGradientArgType::Const => {}
            }
        }
        let void_type = self.context.void_type();
        let fn_type = void_type.fn_type(fn_type.as_slice(), false);
        let function = self.module.add_function(fn_name, fn_type, None);

        // add the noalias attribute to the pointer arguments
        let alias_id = Attribute::get_named_enum_kind_id("noalias");
        let noalias = self.context.create_enum_attribute(alias_id, 0);
        for i in ptr_arg_indices {
            function.add_attribute(AttributeLoc::Param(i), noalias);
        }

        let basic_block = self.context.append_basic_block(function, "entry");
        self.fn_value_opt = Some(function);
        self.builder.position_at_end(basic_block);

        let mut enzyme_fn_args: Vec<BasicMetadataValueEnum> = Vec::new();
        let mut input_activity = Vec::new();
        let mut arg_trees = Vec::new();
        for (i, arg) in original_function.get_param_iter().enumerate() {
            let param_index = start_param_index[i];
            let fn_arg = function.get_nth_param(param_index).unwrap();

            // we'll probably only get doubles or pointers to doubles, so let's assume this for now
            // todo: perhaps refactor this into a recursive function, might be overkill
            let concrete_type = match arg.get_type() {
                BasicTypeEnum::PointerType(_) => CConcreteType_DT_Pointer,
                BasicTypeEnum::FloatType(t) => {
                    if t == self.context.f64_type() {
                        CConcreteType_DT_Double
                    } else {
                        panic!("unsupported type")
                    }
                }
                BasicTypeEnum::IntType(_) => CConcreteType_DT_Integer,
                _ => panic!("unsupported type"),
            };
            let new_tree = unsafe {
                EnzymeNewTypeTreeCT(
                    concrete_type,
                    self.context.as_ctx_ref() as *mut LLVMOpaqueContext,
                )
            };
            unsafe { EnzymeTypeTreeOnlyEq(new_tree, -1) };

            // pointer to double
            if concrete_type == CConcreteType_DT_Pointer {
                // assume the pointer is to a double
                let inner_concrete_type = CConcreteType_DT_Double;
                let inner_new_tree = unsafe {
                    EnzymeNewTypeTreeCT(
                        inner_concrete_type,
                        self.context.as_ctx_ref() as *mut LLVMOpaqueContext,
                    )
                };
                unsafe { EnzymeTypeTreeOnlyEq(inner_new_tree, -1) };
                unsafe { EnzymeTypeTreeOnlyEq(inner_new_tree, -1) };
                unsafe { EnzymeMergeTypeTree(new_tree, inner_new_tree) };
            }

            arg_trees.push(new_tree);
            match args_type[i] {
                CompileGradientArgType::Dup => {
                    // pass in the arg value
                    enzyme_fn_args.push(fn_arg.into());

                    // pass in the darg value
                    let fn_darg = function.get_nth_param(param_index + 1).unwrap();
                    enzyme_fn_args.push(fn_darg.into());

                    input_activity.push(CDIFFE_TYPE_DFT_DUP_ARG);
                }
                CompileGradientArgType::DupNoNeed => {
                    // pass in the arg value
                    enzyme_fn_args.push(fn_arg.into());

                    // pass in the darg value
                    let fn_darg = function.get_nth_param(param_index + 1).unwrap();
                    enzyme_fn_args.push(fn_darg.into());

                    input_activity.push(CDIFFE_TYPE_DFT_DUP_NONEED);
                }
                CompileGradientArgType::Const => {
                    // pass in the arg value
                    enzyme_fn_args.push(fn_arg.into());

                    input_activity.push(CDIFFE_TYPE_DFT_CONSTANT);
                }
            }
        }
        // the wrapped function returns void, so these must be false
        let ret_primary_ret = false;
        let diff_ret = false;
        let ret_activity = CDIFFE_TYPE_DFT_CONSTANT;
        let ret_tree = unsafe {
            EnzymeNewTypeTreeCT(
                CConcreteType_DT_Anything,
                self.context.as_ctx_ref() as *mut LLVMOpaqueContext,
            )
        };

        // always optimize
        let fnc_opt_base = true;
        let logic_ref: EnzymeLogicRef = unsafe { CreateEnzymeLogic(fnc_opt_base as u8) };

        let kv_tmp = IntList {
            data: std::ptr::null_mut(),
            size: 0,
        };
        let mut known_values = vec![kv_tmp; input_activity.len()];

        let fn_type_info = CFnTypeInfo {
            Arguments: arg_trees.as_mut_ptr(),
            Return: ret_tree,
            KnownValues: known_values.as_mut_ptr(),
        };

        let type_analysis: EnzymeTypeAnalysisRef =
            unsafe { CreateTypeAnalysis(logic_ref, std::ptr::null_mut(), std::ptr::null_mut(), 0) };

        let mut args_uncacheable = vec![0; arg_trees.len()];

        let enzyme_function = match mode {
            CompileMode::Forward | CompileMode::ForwardSens => unsafe {
                EnzymeCreateForwardDiff(
                    logic_ref, // Logic
                    std::ptr::null_mut(),
                    std::ptr::null_mut(),
                    original_function.as_value_ref(),
                    ret_activity, // LLVM function, return type
                    input_activity.as_mut_ptr(),
                    input_activity.len(), // constant arguments
                    type_analysis,        // type analysis struct
                    ret_primary_ret as u8,
                    CDerivativeMode_DEM_ForwardMode, // return value, dret_used, top_level which was 1
                    1,                               // free memory
                    0,                               // runtime activity
                    1,                               // vector mode width
                    std::ptr::null_mut(),
                    fn_type_info, // additional_arg, type info (return + args)
                    args_uncacheable.as_mut_ptr(),
                    args_uncacheable.len(), // uncacheable arguments
                    std::ptr::null_mut(),   // write augmented function to this
                )
            },
            CompileMode::Reverse | CompileMode::ReverseSens => {
                let mut call_enzyme = || unsafe {
                    EnzymeCreatePrimalAndGradient(
                        logic_ref,
                        std::ptr::null_mut(),
                        std::ptr::null_mut(),
                        original_function.as_value_ref(),
                        ret_activity,
                        input_activity.as_mut_ptr(),
                        input_activity.len(),
                        type_analysis,
                        ret_primary_ret as u8,
                        diff_ret as u8,
                        CDerivativeMode_DEM_ReverseModeCombined,
                        0,
                        1,
                        1,
                        std::ptr::null_mut(),
                        0,
                        fn_type_info,
                        args_uncacheable.as_mut_ptr(),
                        args_uncacheable.len(),
                        std::ptr::null_mut(),
                        if self.threaded { 1 } else { 0 },
                    )
                };
                if self.threaded {
                    // registering the call handler alters a global variable, so hold the lock
                    let _lock = my_mutex.lock().unwrap();
                    let barrier_string = CString::new("barrier").unwrap();
                    unsafe {
                        EnzymeRegisterCallHandler(
                            barrier_string.as_ptr(),
                            Some(fwd_handler),
                            Some(rev_handler),
                        )
                    };
                    let ret = call_enzyme();
                    // unregister the handler so another thread doesn't pick it up
                    unsafe { EnzymeRegisterCallHandler(barrier_string.as_ptr(), None, None) };
                    ret
                } else {
                    call_enzyme()
                }
            }
        };

        // free the enzyme data structures
        unsafe { FreeEnzymeLogic(logic_ref) };
        unsafe { FreeTypeAnalysis(type_analysis) };
        unsafe { EnzymeFreeTypeTree(ret_tree) };
        for tree in arg_trees {
            unsafe { EnzymeFreeTypeTree(tree) };
        }

        // call the generated enzyme function
        let enzyme_function =
            unsafe { FunctionValue::new(enzyme_function as LLVMValueRef) }.unwrap();
        self.builder
            .build_call(enzyme_function, enzyme_fn_args.as_slice(), "enzyme_call")?;

        // return
        self.builder.build_return(None)?;

        if function.verify(true) {
            Ok(function)
        } else {
            function.print_to_stderr();
            enzyme_function.print_to_stderr();
            unsafe {
                function.delete();
            }
            Err(anyhow!("Invalid generated function."))
        }
    }

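    /// Compile the `get_dims` function, which reports the model dimensions
    /// (number of states, inputs, outputs, data entries and stop conditions,
    /// plus a has-mass flag) through six integer out-pointers.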
    pub fn compile_get_dims(&mut self, model: &DiscreteModel) -> Result<FunctionValue<'ctx>> {
        self.clear();
        let fn_type = self.context.void_type().fn_type(
            &[
                self.int_ptr_type.into(),
                self.int_ptr_type.into(),
                self.int_ptr_type.into(),
                self.int_ptr_type.into(),
                self.int_ptr_type.into(),
                self.int_ptr_type.into(),
            ],
            false,
        );

        let function = self.module.add_function("get_dims", fn_type, None);
        let block = self.context.append_basic_block(function, "entry");
        let fn_arg_names = &["states", "inputs", "outputs", "data", "stop", "has_mass"];
        self.builder.position_at_end(block);

        for (i, arg) in function.get_param_iter().enumerate() {
            let name = fn_arg_names[i];
            let alloca = self.function_arg_alloca(name, arg);
            self.insert_param(name, alloca);
        }

        self.insert_indices();

        let number_of_states = model.state().nnz() as u64;
        let number_of_inputs = model.inputs().iter().fold(0, |acc, x| acc + x.nnz()) as u64;
        let number_of_outputs = match model.out() {
            Some(out) => out.nnz() as u64,
            None => 0,
        };
        let number_of_stop = if let Some(stop) = model.stop() {
            stop.nnz() as u64
        } else {
            0
        };
        let has_mass = if model.lhs().is_some() { 1u64 } else { 0u64 };
        let data_len = self.layout.data().len() as u64;
        self.builder.build_store(
            *self.get_param("states"),
            self.int_type.const_int(number_of_states, false),
        )?;
        self.builder.build_store(
            *self.get_param("inputs"),
            self.int_type.const_int(number_of_inputs, false),
        )?;
        self.builder.build_store(
            *self.get_param("outputs"),
            self.int_type.const_int(number_of_outputs, false),
        )?;
        self.builder.build_store(
            *self.get_param("data"),
            self.int_type.const_int(data_len, false),
        )?;
        self.builder.build_store(
            *self.get_param("stop"),
            self.int_type.const_int(number_of_stop, false),
        )?;
        self.builder.build_store(
            *self.get_param("has_mass"),
            self.int_type.const_int(has_mass, false),
        )?;
        self.builder.build_return(None)?;

        if function.verify(true) {
            Ok(function)
        } else {
            function.print_to_stderr();
            unsafe {
                function.delete();
            }
            Err(anyhow!("Invalid generated function."))
        }
    }

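    /// Compile a `get_tensor_<name>` accessor that returns a pointer to the
    /// named tensor inside `data`, together with its number of non-zeros.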
    pub fn compile_get_tensor(
        &mut self,
        model: &DiscreteModel,
        name: &str,
    ) -> Result<FunctionValue<'ctx>> {
        self.clear();
        let real_ptr_ptr_type = Self::pointer_type(self.context, self.real_ptr_type.into());
        let fn_type = self.context.void_type().fn_type(
            &[
                self.real_ptr_type.into(),
                real_ptr_ptr_type.into(),
                self.int_ptr_type.into(),
            ],
            false,
        );
        let function_name = format!("get_tensor_{name}");
        let function = self
            .module
            .add_function(function_name.as_str(), fn_type, None);
        let basic_block = self.context.append_basic_block(function, "entry");
        self.fn_value_opt = Some(function);

        let fn_arg_names = &["data", "tensor_data", "tensor_size"];
        self.builder.position_at_end(basic_block);

        for (i, arg) in function.get_param_iter().enumerate() {
            // use `arg_name` so the tensor `name` argument isn't shadowed
            let arg_name = fn_arg_names[i];
            let alloca = self.function_arg_alloca(arg_name, arg);
            self.insert_param(arg_name, alloca);
        }

        self.insert_data(model);
        let ptr = self.get_param(name);
        let tensor_size = self.layout.get_layout(name).unwrap().nnz() as u64;
        let tensor_size_value = self.int_type.const_int(tensor_size, false);
        self.builder
            .build_store(*self.get_param("tensor_data"), ptr.as_basic_value_enum())?;
        self.builder
            .build_store(*self.get_param("tensor_size"), tensor_size_value)?;
        self.builder.build_return(None)?;

        if function.verify(true) {
            Ok(function)
        } else {
            function.print_to_stderr();
            unsafe {
                function.delete();
            }
            Err(anyhow!("Invalid generated function."))
        }
    }

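    /// Compile a `get_constant_<name>` accessor that returns a pointer to the
    /// named constant tensor, together with its number of non-zeros.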
    pub fn compile_get_constant(
        &mut self,
        model: &DiscreteModel,
        name: &str,
    ) -> Result<FunctionValue<'ctx>> {
        self.clear();
        let real_ptr_ptr_type = Self::pointer_type(self.context, self.real_ptr_type.into());
        let fn_type = self
            .context
            .void_type()
            .fn_type(&[real_ptr_ptr_type.into(), self.int_ptr_type.into()], false);
        let function_name = format!("get_constant_{name}");
        let function = self
            .module
            .add_function(function_name.as_str(), fn_type, None);
        let basic_block = self.context.append_basic_block(function, "entry");
        self.fn_value_opt = Some(function);

        let fn_arg_names = &["tensor_data", "tensor_size"];
        self.builder.position_at_end(basic_block);

        for (i, arg) in function.get_param_iter().enumerate() {
            // use `arg_name` so the constant `name` argument isn't shadowed
            let arg_name = fn_arg_names[i];
            let alloca = self.function_arg_alloca(arg_name, arg);
            self.insert_param(arg_name, alloca);
        }

        self.insert_constants(model);
        let ptr = self.get_param(name);
        let tensor_size = self.layout.get_layout(name).unwrap().nnz() as u64;
        let tensor_size_value = self.int_type.const_int(tensor_size, false);
        self.builder
            .build_store(*self.get_param("tensor_data"), ptr.as_basic_value_enum())?;
        self.builder
            .build_store(*self.get_param("tensor_size"), tensor_size_value)?;
        self.builder.build_return(None)?;

        if function.verify(true) {
            Ok(function)
        } else {
            function.print_to_stderr();
            unsafe {
                function.delete();
            }
            Err(anyhow!("Invalid generated function."))
        }
    }

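    /// Compile `get_inputs` (when `is_get` is true) or `set_inputs`, which
    /// copies every input tensor between the packed `inputs` array and its
    /// location in `data`, one element-wise copy loop per input.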
    pub fn compile_inputs(
        &mut self,
        model: &DiscreteModel,
        is_get: bool,
    ) -> Result<FunctionValue<'ctx>> {
        self.clear();
        let void_type = self.context.void_type();
        let fn_type = void_type.fn_type(
            &[self.real_ptr_type.into(), self.real_ptr_type.into()],
            false,
        );
        let function_name = if is_get { "get_inputs" } else { "set_inputs" };
        let function = self.module.add_function(function_name, fn_type, None);
        let mut block = self.context.append_basic_block(function, "entry");
        self.fn_value_opt = Some(function);

        let fn_arg_names = &["inputs", "data"];
        self.builder.position_at_end(block);

        for (i, arg) in function.get_param_iter().enumerate() {
            let name = fn_arg_names[i];
            let alloca = self.function_arg_alloca(name, arg);
            self.insert_param(name, alloca);
        }

        let mut inputs_index = 0usize;
        for input in model.inputs() {
            let name = format!("input_{}", input.name());
            self.insert_tensor(input, false);
            let ptr = self.get_var(input);
            // loop through the elements of this input and set/get them using the inputs ptr
            let inputs_start_index = self.int_type.const_int(inputs_index as u64, false);
            let start_index = self.int_type.const_int(0, false);
            let end_index = self
                .int_type
                .const_int(input.nnz().try_into().unwrap(), false);

            let input_block = self.context.append_basic_block(function, name.as_str());
            self.builder.build_unconditional_branch(input_block)?;
            self.builder.position_at_end(input_block);
            let index = self.builder.build_phi(self.int_type, "i")?;
            index.add_incoming(&[(&start_index, block)]);

            // loop body - copy the value between the inputs array and data
            let curr_input_index = index.as_basic_value().into_int_value();
            let input_ptr = Self::get_ptr_to_index(
                &self.builder,
                self.real_type,
                ptr,
                curr_input_index,
                name.as_str(),
            );
            let curr_inputs_index =
                self.builder
                    .build_int_add(inputs_start_index, curr_input_index, name.as_str())?;
            let inputs_ptr = Self::get_ptr_to_index(
                &self.builder,
                self.real_type,
                self.get_param("inputs"),
                curr_inputs_index,
                name.as_str(),
            );
            if is_get {
                let input_value = self
                    .build_load(self.real_type, input_ptr, name.as_str())?
                    .into_float_value();
                self.builder.build_store(inputs_ptr, input_value)?;
            } else {
                let input_value = self
                    .build_load(self.real_type, inputs_ptr, name.as_str())?
                    .into_float_value();
                self.builder.build_store(input_ptr, input_value)?;
            }

            // increment loop index
            let one = self.int_type.const_int(1, false);
            let next_index = self
                .builder
                .build_int_add(curr_input_index, one, name.as_str())?;
            index.add_incoming(&[(&next_index, input_block)]);

            // loop condition
            let loop_while = self.builder.build_int_compare(
                IntPredicate::ULT,
                next_index,
                end_index,
                name.as_str(),
            )?;
            let post_block = self.context.append_basic_block(function, name.as_str());
            self.builder
                .build_conditional_branch(loop_while, input_block, post_block)?;
            self.builder.position_at_end(post_block);

            // get ready for the next input
            block = post_block;
            inputs_index += input.nnz();
        }
        self.builder.build_return(None)?;

        if function.verify(true) {
            Ok(function)
        } else {
            function.print_to_stderr();
            unsafe {
                function.delete();
            }
            Err(anyhow!("Invalid generated function."))
        }
    }

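    /// Compile the `set_id` function, which fills the `id` buffer with 1.0
    /// for differential state elements and 0.0 for algebraic ones.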
    pub fn compile_set_id(&mut self, model: &DiscreteModel) -> Result<FunctionValue<'ctx>> {
        self.clear();
        let void_type = self.context.void_type();
        let fn_type = void_type.fn_type(&[self.real_ptr_type.into()], false);
        let function = self.module.add_function("set_id", fn_type, None);
        let mut block = self.context.append_basic_block(function, "entry");

        let fn_arg_names = &["id"];
        self.builder.position_at_end(block);

        for (i, arg) in function.get_param_iter().enumerate() {
            let name = fn_arg_names[i];
            let alloca = self.function_arg_alloca(name, arg);
            self.insert_param(name, alloca);
        }

        let mut id_index = 0usize;
        for (blk, is_algebraic) in zip(model.state().elmts(), model.is_algebraic()) {
            let name = blk.name().unwrap_or("unknown");
            // loop through the elements of this state block and set the corresponding elements of id
            let id_start_index = self.int_type.const_int(id_index as u64, false);
            let blk_start_index = self.int_type.const_int(0, false);
            let blk_end_index = self
                .int_type
                .const_int(blk.nnz().try_into().unwrap(), false);

            let blk_block = self.context.append_basic_block(function, name);
            self.builder.build_unconditional_branch(blk_block)?;
            self.builder.position_at_end(blk_block);
            let index = self.builder.build_phi(self.int_type, "i")?;
            index.add_incoming(&[(&blk_start_index, block)]);

            // loop body - store 1.0 for differential variables and 0.0 for algebraic ones
            let curr_blk_index = index.as_basic_value().into_int_value();
            let curr_id_index = self
                .builder
                .build_int_add(id_start_index, curr_blk_index, name)?;
            let id_ptr = Self::get_ptr_to_index(
                &self.builder,
                self.real_type,
                self.get_param("id"),
                curr_id_index,
                name,
            );
            let is_algebraic_float = if *is_algebraic {
                0.0 as RealType
            } else {
                1.0 as RealType
            };
            let is_algebraic_value = self.real_type.const_float(is_algebraic_float);
            self.builder.build_store(id_ptr, is_algebraic_value)?;

            // increment loop index
            let one = self.int_type.const_int(1, false);
            let next_index = self.builder.build_int_add(curr_blk_index, one, name)?;
            index.add_incoming(&[(&next_index, blk_block)]);

            // loop condition
            let loop_while = self.builder.build_int_compare(
                IntPredicate::ULT,
                next_index,
                blk_end_index,
                name,
            )?;
            let post_block = self.context.append_basic_block(function, name);
            self.builder
                .build_conditional_branch(loop_while, blk_block, post_block)?;
            self.builder.position_at_end(post_block);

            // get ready for the next block
            block = post_block;
            id_index += blk.nnz();
        }
        self.builder.build_return(None)?;

        if function.verify(true) {
            Ok(function)
        } else {
            function.print_to_stderr();
            unsafe {
                function.delete();
            }
            Err(anyhow!("Invalid generated function."))
        }
    }

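    /// Return a reference to the underlying LLVM module.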
    pub fn module(&self) -> &Module<'ctx> {
        &self.module
    }
}