Skip to main content

diffsl/execution/llvm/
codegen.rs

1use aliasable::boxed::AliasableBox;
2use anyhow::{anyhow, Result};
3use core::panic;
4use inkwell::attributes::{Attribute, AttributeLoc};
5use inkwell::basic_block::BasicBlock;
6use inkwell::builder::Builder;
7use inkwell::context::{AsContextRef, Context};
8use inkwell::debug_info::AsDIScope;
9use inkwell::debug_info::{DICompileUnit, DIFlags, DIFlagsConstants, DebugInfoBuilder};
10use inkwell::execution_engine::ExecutionEngine;
11use inkwell::intrinsics::Intrinsic;
12use inkwell::module::{Linkage, Module};
13use inkwell::passes::PassBuilderOptions;
14use inkwell::targets::{FileType, InitializationConfig, Target, TargetMachine, TargetTriple};
15use inkwell::types::{
16    BasicMetadataTypeEnum, BasicType, BasicTypeEnum, FloatType, FunctionType, IntType, PointerType,
17};
18use inkwell::values::{
19    AsValueRef, BasicMetadataValueEnum, BasicValue, BasicValueEnum, CallSiteValue, FloatValue,
20    FunctionValue, GlobalValue, IntValue, PointerValue,
21};
22use inkwell::{
23    AddressSpace, AtomicOrdering, AtomicRMWBinOp, FloatPredicate, GlobalVisibility, IntPredicate,
24    OptimizationLevel,
25};
26use llvm_sys::core::{
27    LLVMBuildCall2, LLVMGetArgOperand, LLVMGetBasicBlockParent, LLVMGetGlobalParent,
28    LLVMGetInstructionParent, LLVMGetNamedFunction, LLVMGlobalGetValueType, LLVMIsMultithreaded,
29};
30use llvm_sys::prelude::{LLVMBuilderRef, LLVMValueRef};
31use pest::Span;
32use std::collections::HashMap;
33use std::ffi::CString;
34use std::iter::zip;
35use std::pin::Pin;
36use target_lexicon::Triple;
37
38use crate::ast::{Ast, AstKind};
39use crate::discretise::{DiscreteModel, Tensor, TensorBlock};
40use crate::enzyme::{
41    CConcreteType_DT_Anything, CConcreteType_DT_Double, CConcreteType_DT_Float,
42    CConcreteType_DT_Integer, CConcreteType_DT_Pointer, CDerivativeMode_DEM_ForwardMode,
43    CDerivativeMode_DEM_ReverseModeCombined, CFnTypeInfo, CreateEnzymeLogic, CreateTypeAnalysis,
44    DiffeGradientUtils, EnzymeCreateForwardDiff, EnzymeCreatePrimalAndGradient, EnzymeFreeTypeTree,
45    EnzymeGradientUtilsNewFromOriginal, EnzymeLogicRef, EnzymeMergeTypeTree, EnzymeNewTypeTreeCT,
46    EnzymeRegisterCallHandler, EnzymeTypeAnalysisRef, EnzymeTypeTreeOnlyEq, FreeEnzymeLogic,
47    FreeTypeAnalysis, GradientUtils, IntList, LLVMOpaqueContext, CDIFFE_TYPE_DFT_CONSTANT,
48    CDIFFE_TYPE_DFT_DUP_ARG, CDIFFE_TYPE_DFT_DUP_NONEED,
49};
50use crate::execution::compiler::CompilerOptions;
51use crate::execution::module::{
52    CodegenModule, CodegenModuleCompile, CodegenModuleEmit, CodegenModuleJit,
53};
54use crate::execution::scalar::RealType;
55use crate::execution::{DataLayout, Translation, TranslationFrom, TranslationTo};
56use lazy_static::lazy_static;
57use std::sync::Mutex;
58
lazy_static! {
    // Process-wide mutex guarding an `i32` that starts at 0.
    // NOTE(review): its users are outside this part of the file; presumably it
    // serialises calls into code that is not thread-safe -- confirm at the call sites.
    static ref my_mutex: Mutex<i32> = Mutex::new(0i32);
}
62
/// Self-referential storage for an LLVM `Context` and the `CodeGen` built on top of it.
///
/// `codegen` is typed as `CodeGen<'static>` but really borrows from `context`, so the
/// struct is marked `PhantomPinned` and is only ever held behind a `Pin<Box<_>>`.
struct ImmovableLlvmModule {
    // actually has the lifetime of `context`;
    // declared first so it is dropped before `context`
    codegen: Option<CodeGen<'static>>,
    // safety: we must never move out of this box as long as `codegen` is alive
    context: AliasableBox<Context>,
    // opts the struct out of `Unpin`
    _pin: std::marker::PhantomPinned,
}
71
/// An LLVM-backed code generation module: the pinned context/code generator pair
/// plus the `TargetMachine` used when running passes and emitting object code.
pub struct LlvmModule {
    // pinned on the heap because `codegen` inside it refers to the sibling `context`
    inner: Pin<Box<ImmovableLlvmModule>>,
    // target machine the module is optimised and emitted for
    machine: TargetMachine,
}
76
// NOTE(review): these impls assert that an `LlvmModule` (and the LLVM context and
// target machine it owns) may be moved to and shared between threads. Nothing in the
// visible code enforces that; confirm against how the module is used by callers.
unsafe impl Send for LlvmModule {}
unsafe impl Sync for LlvmModule {}
79
80impl LlvmModule {
81    fn new(
82        triple: Option<Triple>,
83        model: &DiscreteModel,
84        threaded: bool,
85        real_type: RealType,
86        debug: bool,
87        constants_use_jit: bool,
88    ) -> Result<Self> {
89        let initialization_config = &InitializationConfig::default();
90        Target::initialize_all(initialization_config);
91        let host_triple = Triple::host();
92        let (triple_str, native) = match triple {
93            Some(ref triple) => (triple.to_string(), false),
94            None => (host_triple.to_string(), true),
95        };
96        let triple = TargetTriple::create(triple_str.as_str());
97        let target = Target::from_triple(&triple).unwrap();
98        let cpu = if native {
99            TargetMachine::get_host_cpu_name().to_string()
100        } else {
101            "generic".to_string()
102        };
103        let features = if native {
104            TargetMachine::get_host_cpu_features().to_string()
105        } else {
106            "".to_string()
107        };
108        let machine = target
109            .create_target_machine(
110                &triple,
111                cpu.as_str(),
112                features.as_str(),
113                inkwell::OptimizationLevel::Aggressive,
114                inkwell::targets::RelocMode::PIC,
115                inkwell::targets::CodeModel::Default,
116            )
117            .unwrap();
118
119        let context = AliasableBox::from_unique(Box::new(Context::create()));
120        let mut pinned = Self {
121            inner: Box::pin(ImmovableLlvmModule {
122                codegen: None,
123                context,
124                _pin: std::marker::PhantomPinned,
125            }),
126            machine,
127        };
128
129        let context_ref = pinned.inner.context.as_ref();
130        let real_type_llvm = match real_type {
131            RealType::F32 => context_ref.f32_type(),
132            RealType::F64 => context_ref.f64_type(),
133        };
134        let int_type_llvm = context_ref.i32_type();
135        let ptr_size_bits = pinned
136            .machine
137            .get_target_data()
138            .get_bit_size(&context_ref.ptr_type(AddressSpace::default()));
139        let real_size_bits = pinned
140            .machine
141            .get_target_data()
142            .get_bit_size(&real_type_llvm);
143        let int_size_bits = pinned
144            .machine
145            .get_target_data()
146            .get_bit_size(&int_type_llvm);
147        let codegen = CodeGen::new(
148            model,
149            context_ref,
150            real_type,
151            real_type_llvm,
152            int_type_llvm,
153            threaded,
154            ptr_size_bits,
155            real_size_bits,
156            int_size_bits,
157            debug,
158            constants_use_jit,
159        )?;
160        let codegen = unsafe { std::mem::transmute::<CodeGen<'_>, CodeGen<'static>>(codegen) };
161        unsafe { pinned.inner.as_mut().get_unchecked_mut().codegen = Some(codegen) };
162        Ok(pinned)
163    }
164
165    fn pre_autodiff_optimisation(&mut self) -> Result<()> {
166        //let pass_manager_builder = PassManagerBuilder::create();
167        //pass_manager_builder.set_optimization_level(inkwell::OptimizationLevel::Default);
168        //let pass_manager = PassManager::create(());
169        //pass_manager_builder.populate_module_pass_manager(&pass_manager);
170        //pass_manager.run_on(self.codegen().module());
171
172        //self.codegen().module().print_to_stderr();
173        // optimise at -O2 no unrolling before giving to enzyme
174        let pass_options = PassBuilderOptions::create();
175        //pass_options.set_verify_each(true);
176        //pass_options.set_debug_logging(true);
177        //pass_options.set_loop_interleaving(true);
178        pass_options.set_loop_vectorization(false);
179        pass_options.set_loop_slp_vectorization(false);
180        pass_options.set_loop_unrolling(false);
181        //pass_options.set_forget_all_scev_in_loop_unroll(true);
182        //pass_options.set_licm_mssa_opt_cap(1);
183        //pass_options.set_licm_mssa_no_acc_for_promotion_cap(10);
184        //pass_options.set_call_graph_profile(true);
185        //pass_options.set_merge_functions(true);
186        //let path = "jit_module_before_pre_autodiff_opt.ll";
187        //self.codegen()
188        //    .module()
189        //    .print_to_file(path)
190        //    .map_err(|e| anyhow!("Failed to print module to file: {:?}", e))?;
191
192        //let passes = "default<O2>";
193        let passes = "annotation2metadata,forceattrs,inferattrs,coro-early,function<eager-inv>(lower-expect,simplifycfg<bonus-inst-threshold=1;no-forward-switch-cond;no-switch-range-to-icmp;no-switch-to-lookup;keep-loops;no-hoist-common-insts;no-sink-common-insts>,early-cse<>),openmp-opt,ipsccp,called-value-propagation,globalopt,function(mem2reg),function<eager-inv>(instcombine,simplifycfg<bonus-inst-threshold=1;no-forward-switch-cond;switch-range-to-icmp;no-switch-to-lookup;keep-loops;no-hoist-common-insts;no-sink-common-insts>),require<globals-aa>,function(invalidate<aa>),require<profile-summary>,cgscc(devirt<4>(inline<only-mandatory>,inline,function-attrs,openmp-opt-cgscc,function<eager-inv>(early-cse<memssa>,speculative-execution,jump-threading,correlated-propagation,simplifycfg<bonus-inst-threshold=1;no-forward-switch-cond;switch-range-to-icmp;no-switch-to-lookup;keep-loops;no-hoist-common-insts;no-sink-common-insts>,instcombine,libcalls-shrinkwrap,tailcallelim,simplifycfg<bonus-inst-threshold=1;no-forward-switch-cond;switch-range-to-icmp;no-switch-to-lookup;keep-loops;no-hoist-common-insts;no-sink-common-insts>,reassociate,require<opt-remark-emit>,loop-mssa(loop-instsimplify,loop-simplifycfg,licm<no-allowspeculation>,loop-rotate,licm<allowspeculation>,simple-loop-unswitch<no-nontrivial;trivial>),simplifycfg<bonus-inst-threshold=1;no-forward-switch-cond;switch-range-to-icmp;no-switch-to-lookup;keep-loops;no-hoist-common-insts;no-sink-common-insts>,instcombine,loop(loop-idiom,indvars,loop-deletion),vector-combine,mldst-motion<no-split-footer-bb>,gvn<>,sccp,bdce,instcombine,jump-threading,correlated-propagation,adce,memcpyopt,dse,loop-mssa(licm<allowspeculation>),coro-elide,simplifycfg<bonus-inst-threshold=1;no-forward-switch-cond;switch-range-to-icmp;no-switch-to-lookup;keep-loops;hoist-common-insts;sink-common-insts>,instcombine),coro-split)),deadargelim,coro-cleanup,globalopt,globaldce,elim-avail-extern,rpo-function-attrs,recompute-globalsaa,function<eager-
inv>(float2int,lower-constant-intrinsics,loop(loop-rotate,loop-deletion),loop-distribute,inject-tli-mappings,loop-load-elim,instcombine,simplifycfg<bonus-inst-threshold=1;forward-switch-cond;switch-range-to-icmp;switch-to-lookup;no-keep-loops;hoist-common-insts;sink-common-insts>,vector-combine,instcombine,transform-warning,instcombine,require<opt-remark-emit>,loop-mssa(licm<allowspeculation>),alignment-from-assumptions,loop-sink,instsimplify,div-rem-pairs,tailcallelim,simplifycfg<bonus-inst-threshold=1;no-forward-switch-cond;switch-range-to-icmp;no-switch-to-lookup;keep-loops;no-hoist-common-insts;no-sink-common-insts>),globaldce,constmerge,cg-profile,rel-lookup-table-converter,function(annotation-remarks),verify";
194        let (codegen, machine) = self.codegen_and_machine_mut();
195        codegen
196            .module()
197            .run_passes(passes, machine, pass_options)
198            .map_err(|e| anyhow!("Failed to run passes: {:?}", e))
199
200        //let path = "jit_module_after_pre_autodiff_opt.ll";
201        //self.codegen()
202        //    .module()
203        //    .print_to_file(path)
204        //    .map_err(|e| anyhow!("Failed to print module to file: {:?}", e))
205    }
206
207    fn post_autodiff_optimisation(&mut self) -> Result<()> {
208        // remove noinline attribute from barrier function as only needed for enzyme
209        if let Some(barrier_func) = self.codegen_mut().module().get_function("barrier") {
210            let nolinline_kind_id = Attribute::get_named_enum_kind_id("noinline");
211            barrier_func.remove_enum_attribute(AttributeLoc::Function, nolinline_kind_id);
212        }
213
214        // remove all preprocess_* functions
215        for f in self.codegen_mut().module.get_functions() {
216            if f.get_name().to_str().unwrap().starts_with("preprocess_") {
217                unsafe { f.delete() };
218            }
219        }
220
221        //self.codegen()
222        //    .module()
223        //    .print_to_file("jit_module_before_post_autodiff_opt.ll")
224        //    .unwrap();
225
226        let passes = "default<O3>";
227        let (codegen, machine) = self.codegen_and_machine_mut();
228        codegen
229            .module()
230            .run_passes(passes, machine, PassBuilderOptions::create())
231            .map_err(|e| anyhow!("Failed to run passes: {:?}", e))?;
232
233        //self.codegen()
234        //    .module()
235        //    .print_to_file("jit_module_after_post_autodiff_opt.ll")
236        //    .unwrap();
237
238        Ok(())
239    }
240
241    pub fn print(&self) {
242        self.codegen().module().print_to_stderr();
243    }
244    fn codegen_mut(&mut self) -> &mut CodeGen<'static> {
245        unsafe {
246            self.inner
247                .as_mut()
248                .get_unchecked_mut()
249                .codegen
250                .as_mut()
251                .unwrap()
252        }
253    }
254    fn codegen_and_machine_mut(&mut self) -> (&mut CodeGen<'static>, &TargetMachine) {
255        (
256            unsafe {
257                self.inner
258                    .as_mut()
259                    .get_unchecked_mut()
260                    .codegen
261                    .as_mut()
262                    .unwrap()
263            },
264            &self.machine,
265        )
266    }
267
268    fn codegen(&self) -> &CodeGen<'static> {
269        self.inner.as_ref().get_ref().codegen.as_ref().unwrap()
270    }
271
272    pub fn to_dynamic_library(self, output_path: impl Into<std::path::PathBuf>) -> Result<()> {
273        use std::fs;
274        use std::process::Command;
275
276        let output_path = output_path.into();
277        let object_buffer = self.to_object()?;
278
279        // Create a temporary object file.
280        let temp_dir = std::env::temp_dir();
281        let obj_path = temp_dir.join("diffsl_temp_object.o");
282        fs::write(&obj_path, object_buffer)
283            .map_err(|e| anyhow!("Failed to write temporary object file: {}", e))?;
284
285        let lld = option_env!("DIFFSL_LLVM_LLD")
286            .ok_or_else(|| anyhow!("DIFFSL_LLVM_LLD not set by build script"))?;
287        let linker_name = std::path::Path::new(lld)
288            .file_name()
289            .and_then(|name| name.to_str())
290            .unwrap_or(lld);
291        let is_clang_driver = linker_name.starts_with("clang");
292
293        let mut command = Command::new(lld);
294        if cfg!(target_os = "windows") {
295            if is_clang_driver {
296                command.arg("-shared");
297                command.arg("-o");
298                command.arg(&output_path);
299                command.arg(&obj_path);
300            } else {
301                command.arg("-flavor").arg("link");
302                command.arg("/DLL");
303                command.arg(format!("/OUT:{}", output_path.display()));
304                command.arg(&obj_path);
305            }
306        } else if cfg!(target_os = "macos") {
307            if !lld.ends_with("ld64.lld") && !is_clang_driver {
308                command.arg("-flavor").arg("darwin");
309            }
310            let arch = if cfg!(target_arch = "aarch64") {
311                "arm64"
312            } else if cfg!(target_arch = "x86_64") {
313                "x86_64"
314            } else {
315                return Err(anyhow!("Unsupported macOS architecture for lld invocation"));
316            };
317            let deployment_target =
318                std::env::var("MACOSX_DEPLOYMENT_TARGET").unwrap_or_else(|_| "11.0".to_string());
319            if is_clang_driver {
320                command.arg("-dynamiclib");
321                command.arg("-arch");
322                command.arg(arch);
323                command.arg(format!("-mmacosx-version-min={deployment_target}"));
324                command.arg("-o");
325                command.arg(&output_path);
326                command.arg(&obj_path);
327            } else {
328                command.arg("-arch");
329                command.arg(arch);
330                command.arg("-platform_version");
331                command.arg("macos");
332                command.arg(&deployment_target);
333                command.arg(&deployment_target);
334                command.arg("-dylib");
335                command.arg("-o");
336                command.arg(&output_path);
337                command.arg(&obj_path);
338            }
339        } else {
340            if is_clang_driver {
341                command.arg("-shared");
342                command.arg("-o");
343                command.arg(&output_path);
344                command.arg(&obj_path);
345            } else {
346                command.arg("-flavor").arg("gnu");
347                command.arg("-shared");
348                command.arg("-o");
349                command.arg(&output_path);
350                command.arg(&obj_path);
351            }
352        }
353
354        let status = command
355            .status()
356            .map_err(|e| anyhow!("Failed to invoke lld: {}", e))?;
357        if !status.success() {
358            return Err(anyhow!(
359                "Dynamic library link failed with status: {}",
360                status
361            ));
362        }
363
364        let _ = fs::remove_file(&obj_path);
365        Ok(())
366    }
367}
368
// `CodegenModule` is implemented with an empty body: no items need to be provided here.
impl CodegenModule for LlvmModule {}
370
371impl CodegenModuleCompile for LlvmModule {
372    fn from_discrete_model(
373        model: &DiscreteModel,
374        options: CompilerOptions,
375        triple: Option<Triple>,
376        real_type: RealType,
377        code: Option<&str>,
378    ) -> Result<Self> {
379        let thread_dim = options.mode.thread_dim(model.state().nnz());
380        let threaded = thread_dim > 1;
381        if (unsafe { LLVMIsMultithreaded() } <= 0) {
382            return Err(anyhow!(
383                "LLVM is not compiled with multithreading support, but this codegen module requires it."
384            ));
385        }
386
387        let constants_use_jit = options.constants_use_jit;
388        let mut module = Self::new(
389            triple,
390            model,
391            threaded,
392            real_type,
393            options.debug,
394            constants_use_jit,
395        )?;
396
397        let set_u0 = module.codegen_mut().compile_set_u0(model, code)?;
398        let calc_stop = module.codegen_mut().compile_calc_stop(model, false, code)?;
399        let calc_stop_full = module.codegen_mut().compile_calc_stop(model, true, code)?;
400        let reset = module.codegen_mut().compile_reset(model, false, code)?;
401        let reset_full = module.codegen_mut().compile_reset(model, true, code)?;
402        let rhs = module.codegen_mut().compile_rhs(model, false, code)?;
403        let rhs_full = module.codegen_mut().compile_rhs(model, true, code)?;
404        let mass = module.codegen_mut().compile_mass(model, code)?;
405        let calc_out = module.codegen_mut().compile_calc_out(model, false, code)?;
406        let calc_out_full = module.codegen_mut().compile_calc_out(model, true, code)?;
407        let _set_id = module.codegen_mut().compile_set_id(model)?;
408        let _get_dims = module.codegen_mut().compile_get_dims(model)?;
409        let set_inputs = module.codegen_mut().compile_inputs(model, false)?;
410        let _get_inputs = module.codegen_mut().compile_inputs(model, true)?;
411        let _set_constants =
412            module
413                .codegen_mut()
414                .compile_set_constants(model, code, constants_use_jit)?;
415        let tensor_info = module
416            .codegen()
417            .layout
418            .tensors()
419            .map(|(name, is_constant)| (name.to_string(), is_constant))
420            .collect::<Vec<_>>();
421        for (tensor, is_constant) in tensor_info {
422            if is_constant {
423                module
424                    .codegen_mut()
425                    .compile_get_constant(model, tensor.as_str())?;
426            } else {
427                module
428                    .codegen_mut()
429                    .compile_get_tensor(model, tensor.as_str())?;
430            }
431        }
432
433        module.pre_autodiff_optimisation()?;
434
435        module.codegen_mut().compile_gradient(
436            set_u0,
437            &[
438                CompileGradientArgType::DupNoNeed,
439                CompileGradientArgType::DupNoNeed,
440                CompileGradientArgType::Const,
441                CompileGradientArgType::Const,
442            ],
443            CompileMode::Forward,
444            "set_u0_grad",
445        )?;
446
447        module.codegen_mut().compile_gradient(
448            rhs,
449            &[
450                CompileGradientArgType::Const,
451                CompileGradientArgType::DupNoNeed,
452                CompileGradientArgType::DupNoNeed,
453                CompileGradientArgType::DupNoNeed,
454                CompileGradientArgType::Const,
455                CompileGradientArgType::Const,
456            ],
457            CompileMode::Forward,
458            "rhs_grad",
459        )?;
460
461        module.codegen_mut().compile_gradient(
462            reset,
463            &[
464                CompileGradientArgType::Const,
465                CompileGradientArgType::DupNoNeed,
466                CompileGradientArgType::DupNoNeed,
467                CompileGradientArgType::DupNoNeed,
468                CompileGradientArgType::Const,
469                CompileGradientArgType::Const,
470            ],
471            CompileMode::Forward,
472            "reset_grad",
473        )?;
474
475        module.codegen_mut().compile_gradient(
476            calc_stop,
477            &[
478                CompileGradientArgType::Const,
479                CompileGradientArgType::DupNoNeed,
480                CompileGradientArgType::DupNoNeed,
481                CompileGradientArgType::DupNoNeed,
482                CompileGradientArgType::Const,
483                CompileGradientArgType::Const,
484            ],
485            CompileMode::Forward,
486            "calc_stop_grad",
487        )?;
488
489        module.codegen_mut().compile_gradient(
490            calc_out,
491            &[
492                CompileGradientArgType::Const,
493                CompileGradientArgType::DupNoNeed,
494                CompileGradientArgType::DupNoNeed,
495                CompileGradientArgType::DupNoNeed,
496                CompileGradientArgType::Const,
497                CompileGradientArgType::Const,
498            ],
499            CompileMode::Forward,
500            "calc_out_grad",
501        )?;
502        module.codegen_mut().compile_gradient(
503            set_inputs,
504            &[
505                CompileGradientArgType::DupNoNeed,
506                CompileGradientArgType::DupNoNeed,
507                CompileGradientArgType::Const,
508            ],
509            CompileMode::Forward,
510            "set_inputs_grad",
511        )?;
512
513        module.codegen_mut().compile_gradient(
514            set_u0,
515            &[
516                CompileGradientArgType::DupNoNeed,
517                CompileGradientArgType::DupNoNeed,
518                CompileGradientArgType::Const,
519                CompileGradientArgType::Const,
520            ],
521            CompileMode::Reverse,
522            "set_u0_rgrad",
523        )?;
524
525        module.codegen_mut().compile_gradient(
526            mass,
527            &[
528                CompileGradientArgType::Const,
529                CompileGradientArgType::DupNoNeed,
530                CompileGradientArgType::DupNoNeed,
531                CompileGradientArgType::DupNoNeed,
532                CompileGradientArgType::Const,
533                CompileGradientArgType::Const,
534            ],
535            CompileMode::Reverse,
536            "mass_rgrad",
537        )?;
538
539        module.codegen_mut().compile_gradient(
540            rhs,
541            &[
542                CompileGradientArgType::Const,
543                CompileGradientArgType::DupNoNeed,
544                CompileGradientArgType::DupNoNeed,
545                CompileGradientArgType::DupNoNeed,
546                CompileGradientArgType::Const,
547                CompileGradientArgType::Const,
548            ],
549            CompileMode::Reverse,
550            "rhs_rgrad",
551        )?;
552        module.codegen_mut().compile_gradient(
553            reset,
554            &[
555                CompileGradientArgType::Const,
556                CompileGradientArgType::DupNoNeed,
557                CompileGradientArgType::DupNoNeed,
558                CompileGradientArgType::DupNoNeed,
559                CompileGradientArgType::Const,
560                CompileGradientArgType::Const,
561            ],
562            CompileMode::Reverse,
563            "reset_rgrad",
564        )?;
565        module.codegen_mut().compile_gradient(
566            calc_stop,
567            &[
568                CompileGradientArgType::Const,
569                CompileGradientArgType::DupNoNeed,
570                CompileGradientArgType::DupNoNeed,
571                CompileGradientArgType::DupNoNeed,
572                CompileGradientArgType::Const,
573                CompileGradientArgType::Const,
574            ],
575            CompileMode::Reverse,
576            "calc_stop_rgrad",
577        )?;
578        module.codegen_mut().compile_gradient(
579            calc_out,
580            &[
581                CompileGradientArgType::Const,
582                CompileGradientArgType::DupNoNeed,
583                CompileGradientArgType::DupNoNeed,
584                CompileGradientArgType::DupNoNeed,
585                CompileGradientArgType::Const,
586                CompileGradientArgType::Const,
587            ],
588            CompileMode::Reverse,
589            "calc_out_rgrad",
590        )?;
591
592        module.codegen_mut().compile_gradient(
593            set_inputs,
594            &[
595                CompileGradientArgType::DupNoNeed,
596                CompileGradientArgType::DupNoNeed,
597                CompileGradientArgType::Const,
598            ],
599            CompileMode::Reverse,
600            "set_inputs_rgrad",
601        )?;
602
603        module.codegen_mut().compile_gradient(
604            rhs_full,
605            &[
606                CompileGradientArgType::Const,
607                CompileGradientArgType::Const,
608                CompileGradientArgType::DupNoNeed,
609                CompileGradientArgType::DupNoNeed,
610                CompileGradientArgType::Const,
611                CompileGradientArgType::Const,
612            ],
613            CompileMode::ForwardSens,
614            "rhs_sgrad",
615        )?;
616
617        module.codegen_mut().compile_gradient(
618            reset_full,
619            &[
620                CompileGradientArgType::Const,
621                CompileGradientArgType::Const,
622                CompileGradientArgType::DupNoNeed,
623                CompileGradientArgType::DupNoNeed,
624                CompileGradientArgType::Const,
625                CompileGradientArgType::Const,
626            ],
627            CompileMode::ForwardSens,
628            "reset_sgrad",
629        )?;
630
631        module.codegen_mut().compile_gradient(
632            calc_stop_full,
633            &[
634                CompileGradientArgType::Const,
635                CompileGradientArgType::Const,
636                CompileGradientArgType::DupNoNeed,
637                CompileGradientArgType::DupNoNeed,
638                CompileGradientArgType::Const,
639                CompileGradientArgType::Const,
640            ],
641            CompileMode::ForwardSens,
642            "calc_stop_sgrad",
643        )?;
644
645        module.codegen_mut().compile_gradient(
646            set_u0,
647            &[
648                CompileGradientArgType::DupNoNeed,
649                CompileGradientArgType::DupNoNeed,
650                CompileGradientArgType::Const,
651                CompileGradientArgType::Const,
652            ],
653            CompileMode::ForwardSens,
654            "set_u0_sgrad",
655        )?;
656
657        module.codegen_mut().compile_gradient(
658            calc_out_full,
659            &[
660                CompileGradientArgType::Const,
661                CompileGradientArgType::Const,
662                CompileGradientArgType::DupNoNeed,
663                CompileGradientArgType::DupNoNeed,
664                CompileGradientArgType::Const,
665                CompileGradientArgType::Const,
666            ],
667            CompileMode::ForwardSens,
668            "calc_out_sgrad",
669        )?;
670        module.codegen_mut().compile_gradient(
671            calc_out_full,
672            &[
673                CompileGradientArgType::Const,
674                CompileGradientArgType::Const,
675                CompileGradientArgType::DupNoNeed,
676                CompileGradientArgType::DupNoNeed,
677                CompileGradientArgType::Const,
678                CompileGradientArgType::Const,
679            ],
680            CompileMode::ReverseSens,
681            "calc_out_srgrad",
682        )?;
683
684        module.codegen_mut().compile_gradient(
685            rhs_full,
686            &[
687                CompileGradientArgType::Const,
688                CompileGradientArgType::Const,
689                CompileGradientArgType::DupNoNeed,
690                CompileGradientArgType::DupNoNeed,
691                CompileGradientArgType::Const,
692                CompileGradientArgType::Const,
693            ],
694            CompileMode::ReverseSens,
695            "rhs_srgrad",
696        )?;
697
698        module.codegen_mut().compile_gradient(
699            reset_full,
700            &[
701                CompileGradientArgType::Const,
702                CompileGradientArgType::Const,
703                CompileGradientArgType::DupNoNeed,
704                CompileGradientArgType::DupNoNeed,
705                CompileGradientArgType::Const,
706                CompileGradientArgType::Const,
707            ],
708            CompileMode::ReverseSens,
709            "reset_srgrad",
710        )?;
711
712        module.codegen_mut().compile_gradient(
713            calc_stop_full,
714            &[
715                CompileGradientArgType::Const,
716                CompileGradientArgType::Const,
717                CompileGradientArgType::DupNoNeed,
718                CompileGradientArgType::DupNoNeed,
719                CompileGradientArgType::Const,
720                CompileGradientArgType::Const,
721            ],
722            CompileMode::ReverseSens,
723            "calc_stop_srgrad",
724        )?;
725
726        module.post_autodiff_optimisation()?;
727        Ok(module)
728    }
729}
730
731impl CodegenModuleEmit for LlvmModule {
732    fn to_object(&self) -> Result<Vec<u8>> {
733        let module = self.codegen().module();
734        //module.print_to_stderr();
735        let buffer = self
736            .machine
737            .write_to_memory_buffer(module, FileType::Object)
738            .unwrap()
739            .as_slice()
740            .to_vec();
741        Ok(buffer)
742    }
743}
744impl CodegenModuleJit for LlvmModule {
745    fn jit(&mut self) -> Result<HashMap<String, *const u8>> {
746        let ee = self
747            .codegen()
748            .module()
749            .create_jit_execution_engine(OptimizationLevel::Default)
750            .map_err(|e| anyhow!("Failed to create JIT execution engine: {:?}", e))?;
751
752        //let path = "jit_module.ll";
753        //self.codegen()
754        //    .module()
755        //    .print_to_file(path)
756        //    .map_err(|e| anyhow!("Failed to print module to file: {:?}", e))?;
757        let module = self.codegen().module();
758        let mut symbols = HashMap::new();
759        for function in module.get_functions() {
760            let name = function.get_name().to_str().unwrap();
761            let address = ee.get_function_address(name);
762            if let Ok(address) = address {
763                symbols.insert(name.to_string(), address as *const u8);
764            }
765        }
766        Ok(symbols)
767    }
768}
769
/// Handles to the module-level globals created by `Globals::new`.
struct Globals<'ctx> {
    /// Constant integer array initialised from `layout.indices()`;
    /// `None` when the layout has no indices.
    indices: Option<GlobalValue<'ctx>>,
    /// Real-valued array initialised from `layout.constants()`;
    /// `None` when the layout has no constants.
    constants: Option<GlobalValue<'ctx>>,
    /// Zero-initialised integer counter used by the barrier functions;
    /// only created for threaded builds.
    thread_counter: Option<GlobalValue<'ctx>>,
    /// Zero-initialised, mutable integer global named `enzyme_const_model_index`.
    model_index: GlobalValue<'ctx>,
}
776
777impl<'ctx> Globals<'ctx> {
778    fn new(
779        layout: &DataLayout,
780        module: &Module<'ctx>,
781        int_type: IntType<'ctx>,
782        real_type: FloatType<'ctx>,
783        threaded: bool,
784        constants_use_jit: bool,
785    ) -> Self {
786        let thread_counter = if threaded {
787            let tc = module.add_global(
788                int_type,
789                Some(AddressSpace::default()),
790                "enzyme_const_thread_counter",
791            );
792            // todo: for some reason this doesn't make enzyme think it's inactive
793            // but using enzyme_const in the name does
794            // todo: also, adding this metadata causes the print of the module to segfault,
795            // so maybe a bug in inkwell
796            //let md_string = context.metadata_string("enzyme_inactive");
797            //tc.set_metadata(md_string, 0);
798            let tc_value = int_type.const_zero();
799            tc.set_visibility(GlobalVisibility::Hidden);
800            tc.set_initializer(&tc_value.as_basic_value_enum());
801            Some(tc)
802        } else {
803            None
804        };
805        let constants = if layout.constants().is_empty() {
806            None
807        } else {
808            let constants_array_type =
809                real_type.array_type(u32::try_from(layout.constants().len()).unwrap());
810            let constants = module.add_global(
811                constants_array_type,
812                Some(AddressSpace::default()),
813                "enzyme_const_constants",
814            );
815            constants.set_visibility(GlobalVisibility::Hidden);
816            constants.set_constant(!constants_use_jit);
817            let constants_array_values = layout
818                .constants()
819                .iter()
820                .map(|&value| real_type.const_float(value))
821                .collect::<Vec<_>>();
822            let constants_value = real_type.const_array(constants_array_values.as_slice());
823            constants.set_initializer(&constants_value);
824            Some(constants)
825        };
826        let indices = if layout.indices().is_empty() {
827            None
828        } else {
829            let indices_array_type =
830                int_type.array_type(u32::try_from(layout.indices().len()).unwrap());
831            let indices_array_values = layout
832                .indices()
833                .iter()
834                .map(|&i| int_type.const_int(i as u64, true))
835                .collect::<Vec<IntValue>>();
836            let indices_value = int_type.const_array(indices_array_values.as_slice());
837            let indices = module.add_global(
838                indices_array_type,
839                Some(AddressSpace::default()),
840                "enzyme_const_indices",
841            );
842            indices.set_constant(true);
843            indices.set_visibility(GlobalVisibility::Hidden);
844            indices.set_initializer(&indices_value);
845            Some(indices)
846        };
847        let model_index = module.add_global(
848            int_type,
849            Some(AddressSpace::default()),
850            "enzyme_const_model_index",
851        );
852        model_index.set_visibility(GlobalVisibility::Hidden);
853        model_index.set_constant(false);
854        model_index.set_initializer(&int_type.const_zero());
855        Self {
856            indices,
857            thread_counter,
858            constants,
859            model_index,
860        }
861    }
862}
863
/// Per-argument annotation passed (as a slice, one entry per argument) to
/// `compile_gradient`.
///
/// NOTE(review): the variant names presumably mirror Enzyme's activity
/// annotations (constant / duplicated / duplicated-no-need) — confirm against
/// `compile_gradient`.
pub enum CompileGradientArgType {
    Const,
    Dup,
    DupNoNeed,
}
869
/// Mode selector passed to `compile_gradient`.
///
/// NOTE(review): presumably chooses between forward- and reverse-mode
/// differentiation, with the `Sens` variants used for the sensitivity
/// functions — confirm against `compile_gradient`.
pub enum CompileMode {
    Forward,
    ForwardSens,
    Reverse,
    ReverseSens,
}
876
/// State used while generating LLVM IR for a single model.
pub struct CodeGen<'ctx> {
    /// LLVM context used to create builders, basic blocks, types and attributes.
    context: &'ctx inkwell::context::Context,
    /// Module that receives the generated functions and globals.
    module: Module<'ctx>,
    /// Instruction builder; positioned by `start_function`.
    builder: Builder<'ctx>,
    /// Debug-info builder; `Some` only when `new` is called with `debug = true`.
    dibuilder: Option<DebugInfoBuilder<'ctx>>,
    /// Compile unit created together with `dibuilder`.
    compile_unit: Option<DICompileUnit<'ctx>>,
    /// Named pointers (parameters, globals, ...) registered via `insert_param`
    /// and related helpers.
    variables: HashMap<String, PointerValue<'ctx>>,
    /// Functions registered through `add_function`, keyed by name.
    functions: HashMap<String, FunctionValue<'ctx>>,
    /// Function currently being generated; set by `start_function`.
    fn_value_opt: Option<FunctionValue<'ctx>>,
    /// Initialised to `None` in `new`.
    /// NOTE(review): presumably the destination pointer of the tensor being
    /// compiled — confirm against `jit_compile_tensor`.
    tensor_ptr_opt: Option<PointerValue<'ctx>>,
    /// Project-level description of the real type, as passed to `new`.
    diffsl_real_type: RealType,
    /// LLVM floating-point type used for real values.
    real_type: FloatType<'ctx>,
    /// Pointer type used for pointers to real values.
    real_ptr_type: PointerType<'ctx>,
    /// LLVM integer type used for indices, counters and thread ids.
    int_type: IntType<'ctx>,
    /// Pointer type used for pointers to integer values.
    int_ptr_type: PointerType<'ctx>,
    /// Size in bits reported for pointer arguments in debug info.
    ptr_size_bits: u64,
    /// Size in bits reported for integer arguments in debug info.
    int_size_bits: u64,
    /// Size in bits reported for real arguments in debug info.
    real_size_bits: u64,
    /// Data layout built from the model via `DataLayout::new`.
    layout: DataLayout,
    /// Module-level globals (indices, constants, thread counter, model index).
    globals: Globals<'ctx>,
    /// Whether barrier functions and barrier calls are generated.
    threaded: bool,
    /// Initialised to `None` in `new`.
    /// NOTE(review): presumably keeps an execution engine alive — confirm.
    _ee: Option<ExecutionEngine<'ctx>>,
}
900
/// C-ABI callback that ignores all of its arguments and always returns `1`.
///
/// NOTE(review): presumably registered with Enzyme as the forward-mode custom
/// handler for calls to `barrier`; the meaning of the return value `1` is
/// defined by Enzyme's custom-function API — confirm at the registration site.
unsafe extern "C" fn fwd_handler(
    _builder: LLVMBuilderRef,
    _call_instruction: LLVMValueRef,
    _gutils: *mut GradientUtils,
    _dcall: *mut LLVMValueRef,
    _normal_return: *mut LLVMValueRef,
    _shadow_return: *mut LLVMValueRef,
) -> u8 {
    1
}
911
912unsafe extern "C" fn rev_handler(
913    builder: LLVMBuilderRef,
914    call_instruction: LLVMValueRef,
915    gutils: *mut DiffeGradientUtils,
916    _tape: LLVMValueRef,
917) {
918    let call_block = LLVMGetInstructionParent(call_instruction);
919    let call_function = LLVMGetBasicBlockParent(call_block);
920    let module = LLVMGetGlobalParent(call_function);
921    let name_c_str = CString::new("barrier_grad").unwrap();
922    let barrier_func = LLVMGetNamedFunction(module, name_c_str.as_ptr());
923    let barrier_func_type = LLVMGlobalGetValueType(barrier_func);
924    let barrier_num = LLVMGetArgOperand(call_instruction, 0);
925    let total_barriers = LLVMGetArgOperand(call_instruction, 1);
926    let thread_count = LLVMGetArgOperand(call_instruction, 2);
927    let barrier_num = EnzymeGradientUtilsNewFromOriginal(gutils as *mut GradientUtils, barrier_num);
928    let total_barriers =
929        EnzymeGradientUtilsNewFromOriginal(gutils as *mut GradientUtils, total_barriers);
930    let thread_count =
931        EnzymeGradientUtilsNewFromOriginal(gutils as *mut GradientUtils, thread_count);
932    let mut args = [barrier_num, total_barriers, thread_count];
933    let name_c_str = CString::new("").unwrap();
934    LLVMBuildCall2(
935        builder,
936        barrier_func_type,
937        barrier_func,
938        args.as_mut_ptr(),
939        args.len() as u32,
940        name_c_str.as_ptr(),
941    );
942}
943
/// Value that `compile_print_value` can print: either a real or an integer.
#[allow(dead_code)]
enum PrintValue<'ctx> {
    /// Floating-point value, printed with a `%f` format.
    Real(FloatValue<'ctx>),
    /// Integer value, printed with a `%d` format.
    Int(IntValue<'ctx>),
}
949
950impl<'ctx> CodeGen<'ctx> {
951    #[allow(clippy::too_many_arguments)]
952    pub fn new(
953        model: &DiscreteModel,
954        context: &'ctx inkwell::context::Context,
955        diffsl_real_type: RealType,
956        real_type: FloatType<'ctx>,
957        int_type: IntType<'ctx>,
958        threaded: bool,
959        ptr_size_bits: u64,
960        real_size_bits: u64,
961        int_size_bits: u64,
962        debug: bool,
963        constants_use_jit: bool,
964    ) -> Result<Self> {
965        let builder = context.create_builder();
966        let layout = DataLayout::new(model);
967        let module = context.create_module(model.name());
968        let (dibuilder, compile_unit) = if debug {
969            let (dib, dic) = module.create_debug_info_builder(
970                true,
971                inkwell::debug_info::DWARFSourceLanguage::C,
972                model.name(),
973                ".",
974                "diffsl compiler",
975                true,
976                "",
977                0,
978                "",
979                inkwell::debug_info::DWARFEmissionKind::Full,
980                0,
981                false,
982                false,
983                "",
984                "",
985            );
986            (Some(dib), Some(dic))
987        } else {
988            (None, None)
989        };
990        let globals = Globals::new(
991            &layout,
992            &module,
993            int_type,
994            real_type,
995            threaded,
996            constants_use_jit,
997        );
998        let real_ptr_type = Self::pointer_type(context, real_type.into());
999        let int_ptr_type = Self::pointer_type(context, int_type.into());
1000        let mut ret = Self {
1001            context,
1002            module,
1003            builder,
1004            dibuilder,
1005            compile_unit,
1006            real_type,
1007            real_ptr_type,
1008            variables: HashMap::new(),
1009            functions: HashMap::new(),
1010            fn_value_opt: None,
1011            tensor_ptr_opt: None,
1012            layout,
1013            diffsl_real_type,
1014            int_type,
1015            int_ptr_type,
1016            globals,
1017            threaded,
1018            ptr_size_bits,
1019            real_size_bits,
1020            int_size_bits,
1021            _ee: None,
1022        };
1023        if threaded {
1024            ret.compile_barrier_init()?;
1025            ret.compile_barrier()?;
1026            ret.compile_barrier_grad()?;
1027            // todo: think I can remove this unless I want to call enzyme using a llvm pass
1028            //ret.globals.add_registered_barrier(ret.context, &ret.module);
1029        }
1030        Ok(ret)
1031    }
1032
1033    fn start_function(
1034        &mut self,
1035        function: FunctionValue<'ctx>,
1036        _code: Option<&str>,
1037    ) -> BasicBlock<'ctx> {
1038        let basic_block = self.context.append_basic_block(function, "entry");
1039        self.fn_value_opt = Some(function);
1040        self.builder.position_at_end(basic_block);
1041
1042        if let Some(dibuilder) = &self.dibuilder {
1043            let scope = self
1044                .fn_value_opt
1045                .unwrap()
1046                .get_subprogram()
1047                .unwrap()
1048                .as_debug_info_scope();
1049            let loc = dibuilder.create_debug_location(self.context, 0, 0, scope, None);
1050            self.builder.set_current_debug_location(loc);
1051        }
1052        basic_block
1053    }
1054
1055    fn add_function(
1056        &mut self,
1057        name: &str,
1058        arg_names: &[&str],
1059        arg_types: &[BasicMetadataTypeEnum<'ctx>],
1060        linkage: Option<Linkage>,
1061        is_real_return: bool,
1062    ) -> FunctionValue<'ctx> {
1063        let function_type = if is_real_return {
1064            self.real_type.fn_type(arg_types, false)
1065        } else {
1066            self.context.void_type().fn_type(arg_types, false)
1067        };
1068        let function = self.module.add_function(name, function_type, linkage);
1069        if let (Some(dibuilder), Some(compile_unit)) = (&self.dibuilder, &self.compile_unit) {
1070            let ditypes = arg_names
1071                .iter()
1072                .zip(arg_types.iter())
1073                .map(|(&name, &ty)| {
1074                    let size_in_bits = if ty.is_float_type() {
1075                        self.real_size_bits
1076                    } else if ty.is_int_type() {
1077                        self.int_size_bits
1078                    } else if ty.is_pointer_type() {
1079                        self.ptr_size_bits
1080                    } else {
1081                        unreachable!("Unsupported argument type for debug info")
1082                    };
1083                    dibuilder
1084                        .create_basic_type(
1085                            name,
1086                            size_in_bits,
1087                            0x00,
1088                            <DIFlags as DIFlagsConstants>::PUBLIC,
1089                        )
1090                        .unwrap()
1091                        .as_type()
1092                })
1093                .collect::<Vec<_>>();
1094            let subroutine_type = dibuilder.create_subroutine_type(
1095                compile_unit.get_file(),
1096                None,
1097                &ditypes,
1098                <DIFlags as DIFlagsConstants>::PUBLIC,
1099            );
1100            let func_scope = dibuilder.create_function(
1101                compile_unit.as_debug_info_scope(),
1102                name,
1103                None,
1104                compile_unit.get_file(),
1105                0,
1106                subroutine_type,
1107                true,
1108                true,
1109                0,
1110                <DIFlags as DIFlagsConstants>::PUBLIC,
1111                true,
1112            );
1113            function.set_subprogram(func_scope);
1114        }
1115        self.functions.insert(name.to_owned(), function);
1116        function
1117    }
1118
1119    #[allow(dead_code)]
1120    fn compile_print_value(
1121        &mut self,
1122        name: &str,
1123        value: PrintValue<'ctx>,
1124    ) -> Result<CallSiteValue<'_>> {
1125        // get printf function or declare it if it doesn't exist
1126        let printf = match self.module.get_function("printf") {
1127            Some(f) => f,
1128            // int printf(const char *format, ...)
1129            None => self.add_function(
1130                "printf",
1131                &["format"],
1132                &[self.int_ptr_type.into()],
1133                Some(Linkage::External),
1134                false,
1135            ),
1136        };
1137        let (format_str, format_str_name) = match value {
1138            PrintValue::Real(_) => (format!("{name}: %f\n"), format!("real_format_{name}")),
1139            PrintValue::Int(_) => (format!("{name}: %d\n"), format!("int_format_{name}")),
1140        };
1141        // change format_str to c string
1142        let format_str = CString::new(format_str).unwrap();
1143        // if format_str_name doesn not already exist as a global, add it
1144        let format_str_global = match self.module.get_global(format_str_name.as_str()) {
1145            Some(g) => g,
1146            None => {
1147                let format_str = self.context.const_string(format_str.as_bytes(), true);
1148                let fmt_str =
1149                    self.module
1150                        .add_global(format_str.get_type(), None, format_str_name.as_str());
1151                fmt_str.set_initializer(&format_str);
1152                fmt_str.set_visibility(GlobalVisibility::Hidden);
1153                fmt_str
1154            }
1155        };
1156        // call printf with the format string and the value
1157        let format_str_ptr = self.builder.build_pointer_cast(
1158            format_str_global.as_pointer_value(),
1159            self.int_ptr_type,
1160            "format_str_ptr",
1161        )?;
1162        let value: BasicMetadataValueEnum = match value {
1163            PrintValue::Real(v) => v.into(),
1164            PrintValue::Int(v) => v.into(),
1165        };
1166        self.builder
1167            .build_call(printf, &[format_str_ptr.into(), value], "printf_call")
1168            .map_err(|e| anyhow!("Error building call to printf: {}", e))
1169    }
1170
1171    fn compile_set_constants(
1172        &mut self,
1173        model: &DiscreteModel,
1174        code: Option<&str>,
1175        constants_use_jit: bool,
1176    ) -> Result<FunctionValue<'ctx>> {
1177        self.clear();
1178        let fn_arg_names = &["thread_id", "thread_dim"];
1179        let function = self.add_function(
1180            "set_constants",
1181            fn_arg_names,
1182            &[self.int_type.into(), self.int_type.into()],
1183            None,
1184            false,
1185        );
1186        let _basic_block = self.start_function(function, code);
1187
1188        for (i, arg) in function.get_param_iter().enumerate() {
1189            let name = fn_arg_names[i];
1190            let alloca = self.function_arg_alloca(name, arg);
1191            self.insert_param(name, alloca);
1192        }
1193
1194        self.insert_indices();
1195        self.insert_constants(model);
1196
1197        if constants_use_jit {
1198            let mut nbarriers = 0;
1199            let total_barriers = (model.constant_defns().len()) as u64;
1200            let total_barriers_val = self.int_type.const_int(total_barriers, false);
1201            #[allow(clippy::explicit_counter_loop)]
1202            for a in model.constant_defns() {
1203                self.jit_compile_tensor(a, Some(*self.get_var(a)), code)?;
1204                let barrier_num = self.int_type.const_int(nbarriers + 1, false);
1205                self.jit_compile_call_barrier(barrier_num, total_barriers_val);
1206                nbarriers += 1;
1207            }
1208        }
1209
1210        self.builder.build_return(None)?;
1211
1212        if function.verify(true) {
1213            Ok(function)
1214        } else {
1215            function.print_to_stderr();
1216            self.module.print_to_stderr();
1217            unsafe {
1218                function.delete();
1219            }
1220            Err(anyhow!("Invalid generated function."))
1221        }
1222    }
1223
1224    fn compile_barrier_init(&mut self) -> Result<FunctionValue<'ctx>> {
1225        self.clear();
1226        let function = self.add_function("barrier_init", &[], &[], None, false);
1227        let _entry_block = self.start_function(function, None);
1228
1229        let thread_counter = self.globals.thread_counter.unwrap().as_pointer_value();
1230        self.builder
1231            .build_store(thread_counter, self.int_type.const_zero())?;
1232
1233        self.builder.build_return(None)?;
1234
1235        if function.verify(true) {
1236            Ok(function)
1237        } else {
1238            function.print_to_stderr();
1239            unsafe {
1240                function.delete();
1241            }
1242            Err(anyhow!("Invalid generated function."))
1243        }
1244    }
1245
1246    fn compile_barrier(&mut self) -> Result<FunctionValue<'ctx>> {
1247        self.clear();
1248        let function = self.add_function(
1249            "barrier",
1250            &["barrier_num", "total_barriers", "thread_count"],
1251            &[
1252                self.int_type.into(),
1253                self.int_type.into(),
1254                self.int_type.into(),
1255            ],
1256            None,
1257            false,
1258        );
1259        let nolinline_kind_id = Attribute::get_named_enum_kind_id("noinline");
1260        let noinline = self.context.create_enum_attribute(nolinline_kind_id, 0);
1261        function.add_attribute(AttributeLoc::Function, noinline);
1262
1263        let _entry_block = self.start_function(function, None);
1264        let increment_block = self.context.append_basic_block(function, "increment");
1265        let wait_loop_block = self.context.append_basic_block(function, "wait_loop");
1266        let barrier_done_block = self.context.append_basic_block(function, "barrier_done");
1267
1268        let thread_counter = self.globals.thread_counter.unwrap().as_pointer_value();
1269        let barrier_num = function.get_nth_param(0).unwrap().into_int_value();
1270        let total_barriers = function.get_nth_param(1).unwrap().into_int_value();
1271        let thread_count = function.get_nth_param(2).unwrap().into_int_value();
1272
1273        let nbarrier_equals_total_barriers = self
1274            .builder
1275            .build_int_compare(
1276                IntPredicate::EQ,
1277                barrier_num,
1278                total_barriers,
1279                "nbarrier_equals_total_barriers",
1280            )
1281            .unwrap();
1282        // branch to barrier_done if nbarrier == total_barriers
1283        self.builder.build_conditional_branch(
1284            nbarrier_equals_total_barriers,
1285            barrier_done_block,
1286            increment_block,
1287        )?;
1288        self.builder.position_at_end(increment_block);
1289
1290        let barrier_num_times_thread_count = self
1291            .builder
1292            .build_int_mul(barrier_num, thread_count, "barrier_num_times_thread_count")
1293            .unwrap();
1294
1295        // Atomically increment the barrier counter
1296        let i32_type = self.context.i32_type();
1297        let one = i32_type.const_int(1, false);
1298        self.builder.build_atomicrmw(
1299            AtomicRMWBinOp::Add,
1300            thread_counter,
1301            one,
1302            AtomicOrdering::Monotonic,
1303        )?;
1304
1305        // wait_loop:
1306        self.builder.build_unconditional_branch(wait_loop_block)?;
1307        self.builder.position_at_end(wait_loop_block);
1308
1309        let current_value = self
1310            .builder
1311            .build_load(i32_type, thread_counter, "current_value")?
1312            .into_int_value();
1313
1314        current_value
1315            .as_instruction_value()
1316            .unwrap()
1317            .set_atomic_ordering(AtomicOrdering::Monotonic)
1318            .map_err(|e| anyhow!("Error setting atomic ordering: {:?}", e))?;
1319
1320        let all_threads_done = self.builder.build_int_compare(
1321            IntPredicate::UGE,
1322            current_value,
1323            barrier_num_times_thread_count,
1324            "all_threads_done",
1325        )?;
1326
1327        self.builder.build_conditional_branch(
1328            all_threads_done,
1329            barrier_done_block,
1330            wait_loop_block,
1331        )?;
1332        self.builder.position_at_end(barrier_done_block);
1333
1334        self.builder.build_return(None)?;
1335
1336        if function.verify(true) {
1337            Ok(function)
1338        } else {
1339            function.print_to_stderr();
1340            unsafe {
1341                function.delete();
1342            }
1343            Err(anyhow!("Invalid generated function."))
1344        }
1345    }
1346
1347    fn compile_barrier_grad(&mut self) -> Result<FunctionValue<'ctx>> {
1348        self.clear();
1349        let function = self.add_function(
1350            "barrier_grad",
1351            &["barrier_num", "total_barriers", "thread_count"],
1352            &[
1353                self.int_type.into(),
1354                self.int_type.into(),
1355                self.int_type.into(),
1356            ],
1357            None,
1358            false,
1359        );
1360
1361        let _entry_block = self.start_function(function, None);
1362        let wait_loop_block = self.context.append_basic_block(function, "wait_loop");
1363        let barrier_done_block = self.context.append_basic_block(function, "barrier_done");
1364
1365        let thread_counter = self.globals.thread_counter.unwrap().as_pointer_value();
1366        let barrier_num = function.get_nth_param(0).unwrap().into_int_value();
1367        let total_barriers = function.get_nth_param(1).unwrap().into_int_value();
1368        let thread_count = function.get_nth_param(2).unwrap().into_int_value();
1369
1370        let twice_total_barriers = self
1371            .builder
1372            .build_int_mul(
1373                total_barriers,
1374                self.int_type.const_int(2, false),
1375                "twice_total_barriers",
1376            )
1377            .unwrap();
1378        let twice_total_barriers_minus_barrier_num = self
1379            .builder
1380            .build_int_sub(
1381                twice_total_barriers,
1382                barrier_num,
1383                "twice_total_barriers_minus_barrier_num",
1384            )
1385            .unwrap();
1386        let twice_total_barriers_minus_barrier_num_times_thread_count = self
1387            .builder
1388            .build_int_mul(
1389                twice_total_barriers_minus_barrier_num,
1390                thread_count,
1391                "twice_total_barriers_minus_barrier_num_times_thread_count",
1392            )
1393            .unwrap();
1394
1395        // Atomically increment the barrier counter
1396        let i32_type = self.context.i32_type();
1397        let one = i32_type.const_int(1, false);
1398        self.builder.build_atomicrmw(
1399            AtomicRMWBinOp::Add,
1400            thread_counter,
1401            one,
1402            AtomicOrdering::Monotonic,
1403        )?;
1404
1405        // wait_loop:
1406        self.builder.build_unconditional_branch(wait_loop_block)?;
1407        self.builder.position_at_end(wait_loop_block);
1408
1409        let current_value = self
1410            .builder
1411            .build_load(i32_type, thread_counter, "current_value")?
1412            .into_int_value();
1413        current_value
1414            .as_instruction_value()
1415            .unwrap()
1416            .set_atomic_ordering(AtomicOrdering::Monotonic)
1417            .map_err(|e| anyhow!("Error setting atomic ordering: {:?}", e))?;
1418
1419        let all_threads_done = self.builder.build_int_compare(
1420            IntPredicate::UGE,
1421            current_value,
1422            twice_total_barriers_minus_barrier_num_times_thread_count,
1423            "all_threads_done",
1424        )?;
1425
1426        self.builder.build_conditional_branch(
1427            all_threads_done,
1428            barrier_done_block,
1429            wait_loop_block,
1430        )?;
1431        self.builder.position_at_end(barrier_done_block);
1432
1433        self.builder.build_return(None)?;
1434
1435        if function.verify(true) {
1436            Ok(function)
1437        } else {
1438            function.print_to_stderr();
1439            unsafe {
1440                function.delete();
1441            }
1442            Err(anyhow!("Invalid generated function."))
1443        }
1444    }
1445
1446    fn jit_compile_call_barrier(
1447        &mut self,
1448        barrier_num: IntValue<'ctx>,
1449        total_barriers: IntValue<'ctx>,
1450    ) {
1451        if !self.threaded {
1452            return;
1453        }
1454        let thread_dim = self.get_param("thread_dim");
1455        let thread_dim = self
1456            .builder
1457            .build_load(self.int_type, *thread_dim, "thread_dim")
1458            .unwrap()
1459            .into_int_value();
1460        let barrier = self.get_function("barrier").unwrap();
1461        self.builder
1462            .build_call(
1463                barrier,
1464                &[
1465                    BasicMetadataValueEnum::IntValue(barrier_num),
1466                    BasicMetadataValueEnum::IntValue(total_barriers),
1467                    BasicMetadataValueEnum::IntValue(thread_dim),
1468                ],
1469                "barrier",
1470            )
1471            .unwrap();
1472    }
1473
1474    fn jit_threading_limits(
1475        &mut self,
1476        size: IntValue<'ctx>,
1477    ) -> Result<(
1478        IntValue<'ctx>,
1479        IntValue<'ctx>,
1480        BasicBlock<'ctx>,
1481        BasicBlock<'ctx>,
1482    )> {
1483        let one = self.int_type.const_int(1, false);
1484        let thread_id = self.get_param("thread_id");
1485        let thread_id = self
1486            .builder
1487            .build_load(self.int_type, *thread_id, "thread_id")
1488            .unwrap()
1489            .into_int_value();
1490        let thread_dim = self.get_param("thread_dim");
1491        let thread_dim = self
1492            .builder
1493            .build_load(self.int_type, *thread_dim, "thread_dim")
1494            .unwrap()
1495            .into_int_value();
1496
1497        // start index is i * size / thread_dim
1498        let i_times_size = self
1499            .builder
1500            .build_int_mul(thread_id, size, "i_times_size")?;
1501        let start = self
1502            .builder
1503            .build_int_unsigned_div(i_times_size, thread_dim, "start")?;
1504
1505        // the ending index for thread i is (i+1) * size / thread_dim
1506        let i_plus_one = self.builder.build_int_add(thread_id, one, "i_plus_one")?;
1507        let i_plus_one_times_size =
1508            self.builder
1509                .build_int_mul(i_plus_one, size, "i_plus_one_times_size")?;
1510        let end = self
1511            .builder
1512            .build_int_unsigned_div(i_plus_one_times_size, thread_dim, "end")?;
1513
1514        let test_done = self.builder.get_insert_block().unwrap();
1515        let next_block = self
1516            .context
1517            .append_basic_block(self.fn_value_opt.unwrap(), "threading_block");
1518        self.builder.position_at_end(next_block);
1519
1520        Ok((start, end, test_done, next_block))
1521    }
1522
1523    fn jit_end_threading(
1524        &mut self,
1525        start: IntValue<'ctx>,
1526        end: IntValue<'ctx>,
1527        test_done: BasicBlock<'ctx>,
1528        next: BasicBlock<'ctx>,
1529    ) -> Result<()> {
1530        let exit = self
1531            .context
1532            .append_basic_block(self.fn_value_opt.unwrap(), "exit");
1533        self.builder.build_unconditional_branch(exit)?;
1534        self.builder.position_at_end(test_done);
1535        // done if start == end
1536        let done = self
1537            .builder
1538            .build_int_compare(IntPredicate::EQ, start, end, "done")?;
1539        self.builder.build_conditional_branch(done, exit, next)?;
1540        self.builder.position_at_end(exit);
1541        Ok(())
1542    }
1543
1544    pub fn write_bitcode_to_path(&self, path: &std::path::Path) {
1545        self.module.write_bitcode_to_path(path);
1546    }
1547
1548    fn insert_constants(&mut self, model: &DiscreteModel) {
1549        if let Some(constants) = self.globals.constants.as_ref() {
1550            self.insert_param("constants", constants.as_pointer_value());
1551            for tensor in model.constant_defns() {
1552                self.insert_tensor(tensor, true);
1553            }
1554        }
1555    }
1556
1557    fn insert_data(&mut self, model: &DiscreteModel) {
1558        self.insert_model_index();
1559        self.insert_constants(model);
1560
1561        if let Some(input) = model.input() {
1562            self.insert_tensor(input, false);
1563        }
1564        for tensor in model.input_dep_defns() {
1565            self.insert_tensor(tensor, false);
1566        }
1567        for tensor in model.time_dep_defns() {
1568            self.insert_tensor(tensor, false);
1569        }
1570        for tensor in model.state_dep_defns() {
1571            self.insert_tensor(tensor, false);
1572        }
1573        for tensor in model.state_dep_post_f_defns() {
1574            self.insert_tensor(tensor, false);
1575        }
1576    }
1577
1578    fn pointer_type(context: &'ctx Context, _ty: BasicTypeEnum<'ctx>) -> PointerType<'ctx> {
1579        context.ptr_type(AddressSpace::default())
1580    }
1581
1582    fn fn_pointer_type(context: &'ctx Context, _ty: FunctionType<'ctx>) -> PointerType<'ctx> {
1583        context.ptr_type(AddressSpace::default())
1584    }
1585
1586    fn insert_indices(&mut self) {
1587        if let Some(indices) = self.globals.indices.as_ref() {
1588            let i32_type = self.context.i32_type();
1589            let zero = i32_type.const_int(0, false);
1590            let ptr = unsafe {
1591                indices
1592                    .as_pointer_value()
1593                    .const_in_bounds_gep(i32_type, &[zero])
1594            };
1595            self.variables.insert("indices".to_owned(), ptr);
1596        }
1597    }
1598
1599    fn insert_model_index(&mut self) {
1600        self.insert_param("model_index", self.globals.model_index.as_pointer_value());
1601    }
1602
1603    fn insert_param(&mut self, name: &str, value: PointerValue<'ctx>) {
1604        self.variables.insert(name.to_owned(), value);
1605    }
1606
1607    fn build_gep<T: BasicType<'ctx>>(
1608        &self,
1609        ty: T,
1610        ptr: PointerValue<'ctx>,
1611        ordered_indexes: &[IntValue<'ctx>],
1612        name: &str,
1613    ) -> Result<PointerValue<'ctx>> {
1614        unsafe {
1615            self.builder
1616                .build_gep(ty, ptr, ordered_indexes, name)
1617                .map_err(|e| e.into())
1618        }
1619    }
1620
1621    fn build_load<T: BasicType<'ctx>>(
1622        &self,
1623        ty: T,
1624        ptr: PointerValue<'ctx>,
1625        name: &str,
1626    ) -> Result<BasicValueEnum<'ctx>> {
1627        self.builder.build_load(ty, ptr, name).map_err(|e| e.into())
1628    }
1629
1630    fn get_ptr_to_index<T: BasicType<'ctx>>(
1631        builder: &Builder<'ctx>,
1632        ty: T,
1633        ptr: &PointerValue<'ctx>,
1634        index: IntValue<'ctx>,
1635        name: &str,
1636    ) -> PointerValue<'ctx> {
1637        unsafe {
1638            builder
1639                .build_in_bounds_gep(ty, *ptr, &[index], name)
1640                .unwrap()
1641        }
1642    }
1643
1644    fn insert_state(&mut self, u: &Tensor) {
1645        let mut data_index = 0;
1646        for blk in u.elmts() {
1647            if let Some(name) = blk.name() {
1648                let ptr = self.variables.get("u").unwrap();
1649                let i = self
1650                    .context
1651                    .i32_type()
1652                    .const_int(data_index.try_into().unwrap(), false);
1653                let alloca = Self::get_ptr_to_index(
1654                    &self.create_entry_block_builder(),
1655                    self.real_type,
1656                    ptr,
1657                    i,
1658                    blk.name().unwrap(),
1659                );
1660                self.variables.insert(name.to_owned(), alloca);
1661            }
1662            data_index += blk.nnz();
1663        }
1664    }
1665    fn insert_dot_state(&mut self, dudt: &Tensor) {
1666        let mut data_index = 0;
1667        for blk in dudt.elmts() {
1668            if let Some(name) = blk.name() {
1669                let ptr = self.variables.get("dudt").unwrap();
1670                let i = self
1671                    .context
1672                    .i32_type()
1673                    .const_int(data_index.try_into().unwrap(), false);
1674                let alloca = Self::get_ptr_to_index(
1675                    &self.create_entry_block_builder(),
1676                    self.real_type,
1677                    ptr,
1678                    i,
1679                    blk.name().unwrap(),
1680                );
1681                self.variables.insert(name.to_owned(), alloca);
1682            }
1683            data_index += blk.nnz();
1684        }
1685    }
1686    fn insert_tensor(&mut self, tensor: &Tensor, is_constant: bool) {
1687        let var_name = if is_constant { "constants" } else { "data" };
1688        let ptr = *self.variables.get(var_name).unwrap();
1689        let mut data_index = self.layout.get_data_index(tensor.name()).unwrap();
1690        let i = self
1691            .context
1692            .i32_type()
1693            .const_int(data_index.try_into().unwrap(), false);
1694        let alloca = Self::get_ptr_to_index(
1695            &self.create_entry_block_builder(),
1696            self.real_type,
1697            &ptr,
1698            i,
1699            tensor.name(),
1700        );
1701        self.variables.insert(tensor.name().to_owned(), alloca);
1702
1703        //insert any named blocks
1704        for blk in tensor.elmts() {
1705            if let Some(name) = blk.name() {
1706                let i = self
1707                    .context
1708                    .i32_type()
1709                    .const_int(data_index.try_into().unwrap(), false);
1710                let alloca = Self::get_ptr_to_index(
1711                    &self.create_entry_block_builder(),
1712                    self.real_type,
1713                    &ptr,
1714                    i,
1715                    name,
1716                );
1717                self.variables.insert(name.to_owned(), alloca);
1718            }
1719            // named blocks only supported for rank <= 1, so we can just add the nnz to get the next data index
1720            data_index += blk.nnz();
1721        }
1722    }
1723    fn get_param(&self, name: &str) -> &PointerValue<'ctx> {
1724        self.variables.get(name).unwrap()
1725    }
1726
1727    fn get_var(&self, tensor: &Tensor) -> &PointerValue<'ctx> {
1728        self.variables.get(tensor.name()).unwrap()
1729    }
1730
1731    fn get_function(&mut self, name: &str) -> Option<FunctionValue<'ctx>> {
1732        // Check cache for function
1733        if let Some(&func) = self.functions.get(name) {
1734            return Some(func);
1735        }
1736        let current_block = self.builder.get_insert_block().unwrap();
1737        let current_loc = self.builder.get_current_debug_location().unwrap();
1738        let current_function = self.fn_value_opt.unwrap();
1739
1740        let function = match name {
1741            // support some llvm intrinsics
1742            "sin" | "cos" | "tan" | "exp" | "log" | "log10" | "sqrt" | "abs" | "copysign"
1743            | "pow" | "min" | "max" => {
1744                let intrinsic_name = match name {
1745                    "min" => "minnum",
1746                    "max" => "maxnum",
1747                    "abs" => "fabs",
1748                    _ => name,
1749                };
1750                let llvm_name =
1751                    format!("llvm.{}.{}", intrinsic_name, self.diffsl_real_type.as_str());
1752
1753                let args_types: Vec<BasicTypeEnum> = vec![self.real_type.into()];
1754                self.builder.unset_current_debug_location();
1755
1756                // Try intrinsic first, fall back to libm for all functions
1757                Some(match Intrinsic::find(&llvm_name) {
1758                    Some(intrinsic) => intrinsic.get_declaration(&self.module, &args_types)?,
1759                    None => {
1760                        // Fallback: declare external libm function
1761                        let args_types_meta: Vec<BasicMetadataTypeEnum> =
1762                            vec![self.real_type.into()];
1763                        let function_type = self.real_type.fn_type(&args_types_meta, false);
1764                        self.module
1765                            .add_function(name, function_type, Some(Linkage::External))
1766                    }
1767                })
1768            }
1769            // some custom functions
1770            "sigmoid" => {
1771                let arg_len = 1;
1772                let ret_type = self.real_type;
1773
1774                let args_types = std::iter::repeat_n(ret_type, arg_len)
1775                    .map(|f| f.into())
1776                    .collect::<Vec<BasicMetadataTypeEnum>>();
1777
1778                let fn_val = self.add_function(name, &["x"], &args_types, None, true);
1779
1780                for arg in fn_val.get_param_iter() {
1781                    arg.into_float_value().set_name("x");
1782                }
1783                let _basic_block = self.start_function(fn_val, None);
1784
1785                let x = fn_val.get_nth_param(0)?.into_float_value();
1786                let one = self.real_type.const_float(1.0);
1787                let negx = self.builder.build_float_neg(x, name).ok()?;
1788                let exp = self.get_function("exp").unwrap();
1789                let exp_negx = self
1790                    .builder
1791                    .build_call(exp, &[BasicMetadataValueEnum::FloatValue(negx)], name)
1792                    .ok()?;
1793                let one_plus_exp_negx = self
1794                    .builder
1795                    .build_float_add(
1796                        exp_negx
1797                            .try_as_basic_value()
1798                            .unwrap_basic()
1799                            .into_float_value(),
1800                        one,
1801                        name,
1802                    )
1803                    .ok()?;
1804                let sigmoid = self
1805                    .builder
1806                    .build_float_div(one, one_plus_exp_negx, name)
1807                    .ok()?;
1808                self.builder.build_return(Some(&sigmoid)).ok();
1809                Some(fn_val)
1810            }
1811            "arcsinh" | "arccosh" => {
1812                let arg_len = 1;
1813                let ret_type = self.real_type;
1814
1815                let args_types = std::iter::repeat_n(ret_type, arg_len)
1816                    .map(|f| f.into())
1817                    .collect::<Vec<BasicMetadataTypeEnum>>();
1818
1819                let fn_val = self.add_function(name, &["x"], &args_types, None, true);
1820
1821                for arg in fn_val.get_param_iter() {
1822                    arg.into_float_value().set_name("x");
1823                }
1824
1825                let _basic_block = self.start_function(fn_val, None);
1826                let x = fn_val.get_nth_param(0)?.into_float_value();
1827                let one = match name {
1828                    "arccosh" => self.real_type.const_float(-1.0),
1829                    "arcsinh" => self.real_type.const_float(1.0),
1830                    _ => panic!("unknown function"),
1831                };
1832                let x_squared = self.builder.build_float_mul(x, x, name).ok()?;
1833                let one_plus_x_squared = self.builder.build_float_add(x_squared, one, name).ok()?;
1834                let sqrt = self.get_function("sqrt").unwrap();
1835                let sqrt_one_plus_x_squared = self
1836                    .builder
1837                    .build_call(
1838                        sqrt,
1839                        &[BasicMetadataValueEnum::FloatValue(one_plus_x_squared)],
1840                        name,
1841                    )
1842                    .unwrap()
1843                    .try_as_basic_value()
1844                    .unwrap_basic()
1845                    .into_float_value();
1846                let x_plus_sqrt_one_plus_x_squared = self
1847                    .builder
1848                    .build_float_add(x, sqrt_one_plus_x_squared, name)
1849                    .ok()?;
1850                let ln = self.get_function("log").unwrap();
1851                let result = self
1852                    .builder
1853                    .build_call(
1854                        ln,
1855                        &[BasicMetadataValueEnum::FloatValue(
1856                            x_plus_sqrt_one_plus_x_squared,
1857                        )],
1858                        name,
1859                    )
1860                    .unwrap()
1861                    .try_as_basic_value()
1862                    .unwrap_basic()
1863                    .into_float_value();
1864                self.builder.build_return(Some(&result)).ok();
1865                Some(fn_val)
1866            }
1867            "heaviside" => {
1868                let arg_len = 1;
1869                let ret_type = self.real_type;
1870
1871                let args_types = std::iter::repeat_n(ret_type, arg_len)
1872                    .map(|f| f.into())
1873                    .collect::<Vec<BasicMetadataTypeEnum>>();
1874
1875                let fn_val = self.add_function(name, &["x"], &args_types, None, true);
1876
1877                for arg in fn_val.get_param_iter() {
1878                    arg.into_float_value().set_name("x");
1879                }
1880
1881                let _basic_block = self.start_function(fn_val, None);
1882                let x = fn_val.get_nth_param(0)?.into_float_value();
1883                let zero = self.real_type.const_float(0.0);
1884                let one = self.real_type.const_float(1.0);
1885                let result = self
1886                    .builder
1887                    .build_select(
1888                        self.builder
1889                            .build_float_compare(FloatPredicate::OGE, x, zero, "x >= 0")
1890                            .unwrap(),
1891                        one,
1892                        zero,
1893                        name,
1894                    )
1895                    .ok()?;
1896                self.builder.build_return(Some(&result)).ok();
1897                Some(fn_val)
1898            }
1899            "tanh" | "sinh" | "cosh" => {
1900                let arg_len = 1;
1901                let ret_type = self.real_type;
1902
1903                let args_types = std::iter::repeat_n(ret_type, arg_len)
1904                    .map(|f| f.into())
1905                    .collect::<Vec<BasicMetadataTypeEnum>>();
1906
1907                let fn_val = self.add_function(name, &["x"], &args_types, None, true);
1908
1909                for arg in fn_val.get_param_iter() {
1910                    arg.into_float_value().set_name("x");
1911                }
1912
1913                let _basic_block = self.start_function(fn_val, None);
1914                let x = fn_val.get_nth_param(0)?.into_float_value();
1915                let negx = self.builder.build_float_neg(x, name).ok()?;
1916                let exp = self.get_function("exp").unwrap();
1917                let exp_negx = self
1918                    .builder
1919                    .build_call(exp, &[BasicMetadataValueEnum::FloatValue(negx)], name)
1920                    .ok()?;
1921                let expx = self
1922                    .builder
1923                    .build_call(exp, &[BasicMetadataValueEnum::FloatValue(x)], name)
1924                    .ok()?;
1925                let expx_minus_exp_negx = self
1926                    .builder
1927                    .build_float_sub(
1928                        expx.try_as_basic_value().unwrap_basic().into_float_value(),
1929                        exp_negx
1930                            .try_as_basic_value()
1931                            .unwrap_basic()
1932                            .into_float_value(),
1933                        name,
1934                    )
1935                    .ok()?;
1936                let expx_plus_exp_negx = self
1937                    .builder
1938                    .build_float_add(
1939                        expx.try_as_basic_value().unwrap_basic().into_float_value(),
1940                        exp_negx
1941                            .try_as_basic_value()
1942                            .unwrap_basic()
1943                            .into_float_value(),
1944                        name,
1945                    )
1946                    .ok()?;
1947                let result = match name {
1948                    "tanh" => self
1949                        .builder
1950                        .build_float_div(expx_minus_exp_negx, expx_plus_exp_negx, name)
1951                        .ok()?,
1952                    "sinh" => self
1953                        .builder
1954                        .build_float_div(expx_minus_exp_negx, self.real_type.const_float(2.0), name)
1955                        .ok()?,
1956                    "cosh" => self
1957                        .builder
1958                        .build_float_div(expx_plus_exp_negx, self.real_type.const_float(2.0), name)
1959                        .ok()?,
1960                    _ => panic!("unknown function"),
1961                };
1962                self.builder.build_return(Some(&result)).ok();
1963                Some(fn_val)
1964            }
1965            _ => None,
1966        }?;
1967
1968        if !function.verify(true) {
1969            function.print_to_stderr();
1970            panic!("Invalid generated function for {}", name);
1971        }
1972
1973        self.builder.position_at_end(current_block);
1974        if self.dibuilder.is_some() {
1975            self.builder.set_current_debug_location(current_loc);
1976        }
1977        self.fn_value_opt = Some(current_function);
1978
1979        self.functions.insert(name.to_owned(), function);
1980        Some(function)
1981    }
1982
1983    /// Returns the `FunctionValue` representing the function being compiled.
1984    #[inline]
1985    fn fn_value(&self) -> FunctionValue<'ctx> {
1986        self.fn_value_opt.unwrap()
1987    }
1988
1989    #[inline]
1990    fn tensor_ptr(&self) -> PointerValue<'ctx> {
1991        self.tensor_ptr_opt.unwrap()
1992    }
1993
1994    /// Creates a new builder in the entry block of the function.
1995    fn create_entry_block_builder(&self) -> Builder<'ctx> {
1996        let builder = self.context.create_builder();
1997        let entry = self.fn_value().get_first_basic_block().unwrap();
1998        match entry.get_first_instruction() {
1999            Some(first_instr) => builder.position_before(&first_instr),
2000            None => builder.position_at_end(entry),
2001        }
2002        builder
2003    }
2004
2005    fn jit_compile_scalar(
2006        &mut self,
2007        a: &Tensor,
2008        res_ptr_opt: Option<PointerValue<'ctx>>,
2009    ) -> Result<PointerValue<'ctx>> {
2010        let res_type = self.real_type;
2011        let res_ptr = match res_ptr_opt {
2012            Some(ptr) => ptr,
2013            None => self
2014                .create_entry_block_builder()
2015                .build_alloca(res_type, a.name())?,
2016        };
2017        let name = a.name();
2018        let elmt = a.elmts().first().unwrap();
2019
2020        // if threaded then only the first thread will evaluate the scalar
2021        let curr_block = self.builder.get_insert_block().unwrap();
2022        let mut next_block_opt = None;
2023        if self.threaded {
2024            let next_block = self.context.append_basic_block(self.fn_value(), "next");
2025            self.builder.position_at_end(next_block);
2026            next_block_opt = Some(next_block);
2027        }
2028
2029        let zero = self.int_type.const_zero();
2030        let float_value = self.jit_compile_expr(name, elmt.expr(), &[], elmt, zero)?;
2031        self.builder.build_store(res_ptr, float_value)?;
2032
2033        // complete the threading block
2034        if self.threaded {
2035            let exit_block = self.context.append_basic_block(self.fn_value(), "exit");
2036            self.builder.build_unconditional_branch(exit_block)?;
2037            self.builder.position_at_end(curr_block);
2038
2039            let thread_id = self.get_param("thread_id");
2040            let thread_id = self
2041                .builder
2042                .build_load(self.int_type, *thread_id, "thread_id")
2043                .unwrap()
2044                .into_int_value();
2045            let is_first_thread = self.builder.build_int_compare(
2046                IntPredicate::EQ,
2047                thread_id,
2048                self.int_type.const_zero(),
2049                "is_first_thread",
2050            )?;
2051            self.builder.build_conditional_branch(
2052                is_first_thread,
2053                next_block_opt.unwrap(),
2054                exit_block,
2055            )?;
2056
2057            self.builder.position_at_end(exit_block);
2058        }
2059
2060        Ok(res_ptr)
2061    }
2062
2063    fn jit_compile_tensor(
2064        &mut self,
2065        a: &Tensor,
2066        res_ptr_opt: Option<PointerValue<'ctx>>,
2067        code: Option<&str>,
2068    ) -> Result<PointerValue<'ctx>> {
2069        // treat scalar as a special case (only if not a contraction)
2070        if a.rank() == 0
2071            && a.elmts()
2072                .first()
2073                .is_none_or(|b| b.expr_layout().rank() == 0)
2074        {
2075            return self.jit_compile_scalar(a, res_ptr_opt);
2076        }
2077
2078        let res_type = self.real_type;
2079        let res_ptr = match res_ptr_opt {
2080            Some(ptr) => ptr,
2081            None => self
2082                .create_entry_block_builder()
2083                .build_alloca(res_type, a.name())?,
2084        };
2085
2086        // set up the tensor storage pointer and index into this data
2087        self.tensor_ptr_opt = Some(res_ptr);
2088
2089        for (i, blk) in a.elmts().iter().enumerate() {
2090            if blk.has_values() {
2091                continue;
2092            }
2093            let default = format!("{}-{}", a.name(), i);
2094            let name = blk.name().unwrap_or(default.as_str());
2095            self.jit_compile_block(name, a, blk, code)?;
2096        }
2097        Ok(res_ptr)
2098    }
2099
2100    fn jit_compile_block(
2101        &mut self,
2102        name: &str,
2103        tensor: &Tensor,
2104        elmt: &TensorBlock,
2105        code: Option<&str>,
2106    ) -> Result<()> {
2107        let translation = Translation::new(
2108            elmt.expr_layout(),
2109            elmt.layout(),
2110            elmt.start(),
2111            tensor.layout_ptr(),
2112        );
2113
2114        if let Some(dibuilder) = &self.dibuilder {
2115            let (line_no, column_no) = match (elmt.expr().span, code) {
2116                (Some(s), Some(c)) => {
2117                    let s = Span::new(c, s.pos_start, s.pos_end).unwrap();
2118                    s.start_pos().line_col()
2119                }
2120                _ => (0, 0),
2121            };
2122            let scope = self
2123                .fn_value_opt
2124                .unwrap()
2125                .get_subprogram()
2126                .unwrap()
2127                .as_debug_info_scope();
2128            let loc = dibuilder.create_debug_location(
2129                self.context,
2130                line_no.try_into().unwrap(),
2131                column_no.try_into().unwrap(),
2132                scope,
2133                None,
2134            );
2135            self.builder.set_current_debug_location(loc);
2136        }
2137
2138        // Diagonal 1D->0D contraction produces DenseContraction (since contract_last_axis
2139        // turns it into a Dense scalar); route to dense path for accumulation support.
2140        if elmt.expr_layout().is_dense()
2141            || matches!(translation.source, TranslationFrom::DenseContraction { .. })
2142        {
2143            self.jit_compile_dense_block(name, elmt, &translation)
2144        } else if elmt.expr_layout().is_diagonal() {
2145            self.jit_compile_diagonal_block(name, elmt, &translation)
2146        } else if elmt.expr_layout().is_sparse() {
2147            match translation.source {
2148                TranslationFrom::SparseContraction { .. } => {
2149                    self.jit_compile_sparse_contraction_block(name, elmt, &translation)
2150                }
2151                _ => self.jit_compile_sparse_block(name, elmt, &translation),
2152            }
2153        } else {
2154            Err(anyhow!(
2155                "unsupported block layout: {:?}",
2156                elmt.expr_layout()
2157            ))
2158        }
2159    }
2160
2161    // for dense blocks we can loop through the nested loops to calculate the index, then we compile the expression passing in this index
2162    fn jit_compile_dense_block(
2163        &mut self,
2164        name: &str,
2165        elmt: &TensorBlock,
2166        translation: &Translation,
2167    ) -> Result<()> {
2168        let int_type = self.int_type;
2169
2170        let mut preblock = self.builder.get_insert_block().unwrap();
2171        let expr_rank = elmt.expr_layout().rank();
2172        let expr_shape = elmt
2173            .expr_layout()
2174            .shape()
2175            .mapv(|n| int_type.const_int(n.try_into().unwrap(), false));
2176        let one = int_type.const_int(1, false);
2177
2178        let mut expr_strides = vec![1; expr_rank];
2179        if expr_rank > 0 {
2180            for i in (0..expr_rank - 1).rev() {
2181                expr_strides[i] = expr_strides[i + 1] * elmt.expr_layout().shape()[i + 1];
2182            }
2183        }
2184        let expr_strides = expr_strides
2185            .iter()
2186            .map(|&s| int_type.const_int(s.try_into().unwrap(), false))
2187            .collect::<Vec<IntValue>>();
2188
2189        // setup indices, loop through the nested loops
2190        let mut indices = Vec::new();
2191        let mut blocks = Vec::new();
2192
2193        // allocate the contract sum if needed
2194        let (contract_sum, contract_by, contract_strides) =
2195            if let TranslationFrom::DenseContraction {
2196                contract_by,
2197                contract_len: _,
2198            } = translation.source
2199            {
2200                let contract_rank = expr_rank - contract_by;
2201                let mut contract_strides = vec![1; contract_rank];
2202                for i in (0..contract_rank.saturating_sub(1)).rev() {
2203                    contract_strides[i] =
2204                        contract_strides[i + 1] * elmt.expr_layout().shape()[i + 1];
2205                }
2206                let contract_strides = contract_strides
2207                    .iter()
2208                    .map(|&s| int_type.const_int(s.try_into().unwrap(), false))
2209                    .collect::<Vec<IntValue>>();
2210                (
2211                    Some(self.builder.build_alloca(self.real_type, "contract_sum")?),
2212                    contract_by,
2213                    Some(contract_strides),
2214                )
2215            } else {
2216                (None, 0, None)
2217            };
2218
2219        // we will thread the output loop, except if we are contracting to a scalar
2220        let (thread_start, thread_end, test_done, next) = if self.threaded {
2221            let (start, end, test_done, next) =
2222                self.jit_threading_limits(*expr_shape.get(0).unwrap_or(&one))?;
2223            preblock = next;
2224            (Some(start), Some(end), Some(test_done), Some(next))
2225        } else {
2226            (None, None, None, None)
2227        };
2228
2229        // if all axes are contracted, initialize the sum before the loops
2230        if contract_by == expr_rank {
2231            if let Some(contract_sum) = contract_sum {
2232                self.builder
2233                    .build_store(contract_sum, self.real_type.const_zero())?;
2234            }
2235        }
2236
2237        for i in 0..expr_rank {
2238            let block = self.context.append_basic_block(self.fn_value(), name);
2239            self.builder.build_unconditional_branch(block)?;
2240            self.builder.position_at_end(block);
2241
2242            let start_index = if i == 0 && self.threaded {
2243                thread_start.unwrap()
2244            } else {
2245                self.int_type.const_zero()
2246            };
2247
2248            let curr_index = self.builder.build_phi(int_type, format!["i{i}"].as_str())?;
2249            curr_index.add_incoming(&[(&start_index, preblock)]);
2250
2251            let init_point = (expr_rank - contract_by).saturating_sub(1);
2252            if i == init_point && contract_by != expr_rank {
2253                if let Some(contract_sum) = contract_sum {
2254                    self.builder
2255                        .build_store(contract_sum, self.real_type.const_zero())?;
2256                }
2257            }
2258
2259            indices.push(curr_index);
2260            blocks.push(block);
2261            preblock = block;
2262        }
2263
2264        let indices_int: Vec<IntValue> = indices
2265            .iter()
2266            .map(|i| i.as_basic_value().into_int_value())
2267            .collect();
2268
2269        // if indices = (i, j, k) and shape = (a, b, c) calculate expr_index = (k + j*b + i*b*c)
2270        let mut expr_index = *indices_int.last().unwrap_or(&int_type.const_zero());
2271        let mut stride = 1u64;
2272        if !indices.is_empty() {
2273            for i in (0..indices.len() - 1).rev() {
2274                let iname_i = indices_int[i];
2275                let shapei: u64 = elmt.expr_layout().shape()[i + 1].try_into().unwrap();
2276                stride *= shapei;
2277                let stride_intval = self.context.i32_type().const_int(stride, false);
2278                let stride_mul_i = self.builder.build_int_mul(stride_intval, iname_i, name)?;
2279                expr_index = self.builder.build_int_add(expr_index, stride_mul_i, name)?;
2280            }
2281        }
2282
2283        let float_value =
2284            self.jit_compile_expr(name, elmt.expr(), indices_int.as_slice(), elmt, expr_index)?;
2285
2286        if let Some(contract_sum) = contract_sum {
2287            let contract_sum_value = self
2288                .build_load(self.real_type, contract_sum, "contract_sum")?
2289                .into_float_value();
2290            let new_contract_sum_value = self.builder.build_float_add(
2291                contract_sum_value,
2292                float_value,
2293                "new_contract_sum",
2294            )?;
2295            self.builder
2296                .build_store(contract_sum, new_contract_sum_value)?;
2297        } else {
2298            let expr_index = indices_int.iter().zip(expr_strides.iter()).fold(
2299                self.int_type.const_zero(),
2300                |acc, (i, s)| {
2301                    let tmp = self.builder.build_int_mul(*i, *s, "expr_index").unwrap();
2302                    self.builder.build_int_add(acc, tmp, "acc").unwrap()
2303                },
2304            );
2305            self.jit_compile_broadcast_and_store(name, elmt, float_value, expr_index, translation)?;
2306        }
2307
2308        let mut postblock = self.builder.get_insert_block().unwrap();
2309
2310        // unwind the nested loops
2311        for i in (0..expr_rank).rev() {
2312            // increment index
2313            let next_index = self.builder.build_int_add(indices_int[i], one, name)?;
2314            indices[i].add_incoming(&[(&next_index, postblock)]);
2315            let store_point = (expr_rank - contract_by).saturating_sub(1);
2316            if i == store_point {
2317                if let Some(contract_sum) = contract_sum {
2318                    let contract_sum_value = self
2319                        .build_load(self.real_type, contract_sum, "contract_sum")?
2320                        .into_float_value();
2321                    let contract_strides = contract_strides.as_ref().unwrap();
2322                    let elmt_index = indices_int
2323                        .iter()
2324                        .take(contract_strides.len())
2325                        .zip(contract_strides.iter())
2326                        .fold(self.int_type.const_zero(), |acc, (i, s)| {
2327                            let tmp = self.builder.build_int_mul(*i, *s, "elmt_index").unwrap();
2328                            self.builder.build_int_add(acc, tmp, "acc").unwrap()
2329                        });
2330                    self.jit_compile_store(
2331                        name,
2332                        elmt,
2333                        elmt_index,
2334                        contract_sum_value,
2335                        translation,
2336                    )?;
2337                }
2338            }
2339
2340            let end_index = if i == 0 && self.threaded {
2341                thread_end.unwrap()
2342            } else {
2343                expr_shape[i]
2344            };
2345
2346            // loop condition
2347            let loop_while =
2348                self.builder
2349                    .build_int_compare(IntPredicate::ULT, next_index, end_index, name)?;
2350            let block = self.context.append_basic_block(self.fn_value(), name);
2351            self.builder
2352                .build_conditional_branch(loop_while, blocks[i], block)?;
2353            self.builder.position_at_end(block);
2354            postblock = block;
2355        }
2356
2357        if self.threaded {
2358            self.jit_end_threading(
2359                thread_start.unwrap(),
2360                thread_end.unwrap(),
2361                test_done.unwrap(),
2362                next.unwrap(),
2363            )?;
2364        }
2365        Ok(())
2366    }
2367
2368    fn jit_compile_sparse_contraction_block(
2369        &mut self,
2370        name: &str,
2371        elmt: &TensorBlock,
2372        translation: &Translation,
2373    ) -> Result<()> {
2374        match translation.source {
2375            TranslationFrom::SparseContraction { .. } => {}
2376            _ => {
2377                panic!("expected sparse contraction")
2378            }
2379        }
2380        let int_type = self.int_type;
2381
2382        let translation_index = self
2383            .layout
2384            .get_translation_index(elmt.expr_layout(), elmt.layout())
2385            .unwrap();
2386        let translation_index = translation_index + translation.get_from_index_in_data_layout();
2387
2388        let final_contract_index =
2389            int_type.const_int(elmt.layout().nnz().try_into().unwrap(), false);
2390        let (thread_start, thread_end, test_done, next) = if self.threaded {
2391            let (start, end, test_done, next) = self.jit_threading_limits(final_contract_index)?;
2392            (Some(start), Some(end), Some(test_done), Some(next))
2393        } else {
2394            (None, None, None, None)
2395        };
2396
2397        let preblock = self.builder.get_insert_block().unwrap();
2398        let contract_sum_ptr = self.builder.build_alloca(self.real_type, "contract_sum")?;
2399
2400        // loop through each contraction
2401        let block = self.context.append_basic_block(self.fn_value(), name);
2402        self.builder.build_unconditional_branch(block)?;
2403        self.builder.position_at_end(block);
2404
2405        let contract_index = self.builder.build_phi(int_type, "i")?;
2406        let contract_start = if self.threaded {
2407            thread_start.unwrap()
2408        } else {
2409            int_type.const_zero()
2410        };
2411        contract_index.add_incoming(&[(&contract_start, preblock)]);
2412
2413        let start_index = self.builder.build_int_add(
2414            int_type.const_int(translation_index.try_into().unwrap(), false),
2415            self.builder.build_int_mul(
2416                int_type.const_int(2, false),
2417                contract_index.as_basic_value().into_int_value(),
2418                name,
2419            )?,
2420            name,
2421        )?;
2422        let end_index =
2423            self.builder
2424                .build_int_add(start_index, int_type.const_int(1, false), name)?;
2425        let start_ptr = self.build_gep(
2426            self.int_type,
2427            *self.get_param("indices"),
2428            &[start_index],
2429            "start_index_ptr",
2430        )?;
2431        let start_contract = self
2432            .build_load(self.int_type, start_ptr, "start")?
2433            .into_int_value();
2434        let end_ptr = self.build_gep(
2435            self.int_type,
2436            *self.get_param("indices"),
2437            &[end_index],
2438            "end_index_ptr",
2439        )?;
2440        let end_contract = self
2441            .build_load(self.int_type, end_ptr, "end")?
2442            .into_int_value();
2443
2444        // initialise the contract sum
2445        self.builder
2446            .build_store(contract_sum_ptr, self.real_type.const_float(0.0))?;
2447
2448        // loop through each element in the contraction
2449        let start_contract_block = self
2450            .context
2451            .append_basic_block(self.fn_value(), format!("{name}_contract").as_str());
2452        self.builder
2453            .build_unconditional_branch(start_contract_block)?;
2454        self.builder.position_at_end(start_contract_block);
2455
2456        let expr_index_phi = self.builder.build_phi(int_type, "j")?;
2457        expr_index_phi.add_incoming(&[(&start_contract, block)]);
2458
2459        let expr_index = expr_index_phi.as_basic_value().into_int_value();
2460        let indices_int = self.expr_indices_from_elmt_index(expr_index, elmt, name)?;
2461
2462        // loop body - eval expression and increment sum
2463        let float_value =
2464            self.jit_compile_expr(name, elmt.expr(), indices_int.as_slice(), elmt, expr_index)?;
2465        let contract_sum_value = self
2466            .build_load(self.real_type, contract_sum_ptr, "contract_sum")?
2467            .into_float_value();
2468        let new_contract_sum_value =
2469            self.builder
2470                .build_float_add(contract_sum_value, float_value, "new_contract_sum")?;
2471        self.builder
2472            .build_store(contract_sum_ptr, new_contract_sum_value)?;
2473
2474        let end_contract_block = self.builder.get_insert_block().unwrap();
2475
2476        // increment contract loop index
2477        let next_elmt_index =
2478            self.builder
2479                .build_int_add(expr_index, int_type.const_int(1, false), name)?;
2480        expr_index_phi.add_incoming(&[(&next_elmt_index, end_contract_block)]);
2481
2482        // contract loop condition
2483        let loop_while = self.builder.build_int_compare(
2484            IntPredicate::ULT,
2485            next_elmt_index,
2486            end_contract,
2487            name,
2488        )?;
2489        let post_contract_block = self.context.append_basic_block(self.fn_value(), name);
2490        self.builder.build_conditional_branch(
2491            loop_while,
2492            start_contract_block,
2493            post_contract_block,
2494        )?;
2495        self.builder.position_at_end(post_contract_block);
2496
2497        // store the result
2498        self.jit_compile_store(
2499            name,
2500            elmt,
2501            contract_index.as_basic_value().into_int_value(),
2502            new_contract_sum_value,
2503            translation,
2504        )?;
2505
2506        // increment outer loop index
2507        let next_contract_index = self.builder.build_int_add(
2508            contract_index.as_basic_value().into_int_value(),
2509            int_type.const_int(1, false),
2510            name,
2511        )?;
2512        contract_index.add_incoming(&[(&next_contract_index, post_contract_block)]);
2513
2514        // outer loop condition
2515        let loop_while = self.builder.build_int_compare(
2516            IntPredicate::ULT,
2517            next_contract_index,
2518            thread_end.unwrap_or(final_contract_index),
2519            name,
2520        )?;
2521        let post_block = self.context.append_basic_block(self.fn_value(), name);
2522        self.builder
2523            .build_conditional_branch(loop_while, block, post_block)?;
2524        self.builder.position_at_end(post_block);
2525
2526        if self.threaded {
2527            self.jit_end_threading(
2528                thread_start.unwrap(),
2529                thread_end.unwrap(),
2530                test_done.unwrap(),
2531                next.unwrap(),
2532            )?;
2533        }
2534
2535        Ok(())
2536    }
2537
2538    fn expr_indices_from_elmt_index(
2539        &mut self,
2540        elmt_index: IntValue<'ctx>,
2541        elmt: &TensorBlock,
2542        name: &str,
2543    ) -> Result<Vec<IntValue<'ctx>>, anyhow::Error> {
2544        let layout_index = self.layout.get_layout_index(elmt.expr_layout()).unwrap();
2545        let int_type = self.int_type;
2546        // loop body - load index from layout
2547        let elmt_index_mult_rank = self.builder.build_int_mul(
2548            elmt_index,
2549            int_type.const_int(elmt.expr_layout().rank().try_into().unwrap(), false),
2550            name,
2551        )?;
2552        (0..elmt.expr_layout().rank())
2553            .map(|i| {
2554                let layout_index_plus_offset =
2555                    int_type.const_int((layout_index + i).try_into().unwrap(), false);
2556                let curr_index = self.builder.build_int_add(
2557                    elmt_index_mult_rank,
2558                    layout_index_plus_offset,
2559                    name,
2560                )?;
2561                let ptr = Self::get_ptr_to_index(
2562                    &self.builder,
2563                    self.int_type,
2564                    self.get_param("indices"),
2565                    curr_index,
2566                    name,
2567                );
2568                Ok(self.build_load(self.int_type, ptr, name)?.into_int_value())
2569            })
2570            .collect::<Result<Vec<_>, anyhow::Error>>()
2571    }
2572
2573    // for sparse blocks we can loop through the non-zero elements and extract the index from the layout, then we compile the expression passing in this index
2574    // TODO: havn't implemented contractions yet
2575    fn jit_compile_sparse_block(
2576        &mut self,
2577        name: &str,
2578        elmt: &TensorBlock,
2579        translation: &Translation,
2580    ) -> Result<()> {
2581        let int_type = self.int_type;
2582
2583        let start_index = int_type.const_int(0, false);
2584        let end_index = int_type.const_int(elmt.expr_layout().nnz().try_into().unwrap(), false);
2585
2586        let (thread_start, thread_end, test_done, next) = if self.threaded {
2587            let (start, end, test_done, next) = self.jit_threading_limits(end_index)?;
2588            (Some(start), Some(end), Some(test_done), Some(next))
2589        } else {
2590            (None, None, None, None)
2591        };
2592
2593        // loop through the non-zero elements
2594        let preblock = self.builder.get_insert_block().unwrap();
2595        let loop_block = self.context.append_basic_block(self.fn_value(), name);
2596        self.builder.build_unconditional_branch(loop_block)?;
2597        self.builder.position_at_end(loop_block);
2598
2599        let curr_index = self.builder.build_phi(int_type, "i")?;
2600        curr_index.add_incoming(&[(&thread_start.unwrap_or(start_index), preblock)]);
2601
2602        let elmt_index = curr_index.as_basic_value().into_int_value();
2603        let indices_int = self.expr_indices_from_elmt_index(elmt_index, elmt, name)?;
2604
2605        // loop body - eval expression
2606        let float_value =
2607            self.jit_compile_expr(name, elmt.expr(), indices_int.as_slice(), elmt, elmt_index)?;
2608
2609        self.jit_compile_broadcast_and_store(name, elmt, float_value, elmt_index, translation)?;
2610
2611        // jit_compile_expr or jit_compile_broadcast_and_store may have changed the current block
2612        let end_loop_block = self.builder.get_insert_block().unwrap();
2613
2614        // increment loop index
2615        let one = int_type.const_int(1, false);
2616        let next_index = self.builder.build_int_add(elmt_index, one, name)?;
2617        curr_index.add_incoming(&[(&next_index, end_loop_block)]);
2618
2619        // loop condition
2620        let loop_while = self.builder.build_int_compare(
2621            IntPredicate::ULT,
2622            next_index,
2623            thread_end.unwrap_or(end_index),
2624            name,
2625        )?;
2626        let post_block = self.context.append_basic_block(self.fn_value(), name);
2627        self.builder
2628            .build_conditional_branch(loop_while, loop_block, post_block)?;
2629        self.builder.position_at_end(post_block);
2630
2631        if self.threaded {
2632            self.jit_end_threading(
2633                thread_start.unwrap(),
2634                thread_end.unwrap(),
2635                test_done.unwrap(),
2636                next.unwrap(),
2637            )?;
2638        }
2639
2640        Ok(())
2641    }
2642
2643    // for diagonal blocks we can loop through the diagonal elements and the index is just the same for each element, then we compile the expression passing in this index
2644    fn jit_compile_diagonal_block(
2645        &mut self,
2646        name: &str,
2647        elmt: &TensorBlock,
2648        translation: &Translation,
2649    ) -> Result<()> {
2650        let int_type = self.int_type;
2651
2652        let start_index = int_type.const_int(0, false);
2653        let end_index = int_type.const_int(elmt.expr_layout().nnz().try_into().unwrap(), false);
2654
2655        let (thread_start, thread_end, test_done, next) = if self.threaded {
2656            let (start, end, test_done, next) = self.jit_threading_limits(end_index)?;
2657            (Some(start), Some(end), Some(test_done), Some(next))
2658        } else {
2659            (None, None, None, None)
2660        };
2661
2662        // loop through the non-zero elements
2663        let preblock = self.builder.get_insert_block().unwrap();
2664        let start_loop_block = self.context.append_basic_block(self.fn_value(), name);
2665        self.builder.build_unconditional_branch(start_loop_block)?;
2666        self.builder.position_at_end(start_loop_block);
2667
2668        let curr_index = self.builder.build_phi(int_type, "i")?;
2669        curr_index.add_incoming(&[(&thread_start.unwrap_or(start_index), preblock)]);
2670
2671        // loop body - index is just the same for each element
2672        let elmt_index = curr_index.as_basic_value().into_int_value();
2673        let indices_int: Vec<IntValue> =
2674            (0..elmt.expr_layout().rank()).map(|_| elmt_index).collect();
2675
2676        // loop body - eval expression
2677        let float_value =
2678            self.jit_compile_expr(name, elmt.expr(), indices_int.as_slice(), elmt, elmt_index)?;
2679
2680        // loop body - store result
2681        self.jit_compile_broadcast_and_store(name, elmt, float_value, elmt_index, translation)?;
2682
2683        let end_loop_block = self.builder.get_insert_block().unwrap();
2684
2685        // increment loop index
2686        let one = int_type.const_int(1, false);
2687        let next_index = self.builder.build_int_add(elmt_index, one, name)?;
2688        curr_index.add_incoming(&[(&next_index, end_loop_block)]);
2689
2690        // loop condition
2691        let loop_while = self.builder.build_int_compare(
2692            IntPredicate::ULT,
2693            next_index,
2694            thread_end.unwrap_or(end_index),
2695            name,
2696        )?;
2697        let post_block = self.context.append_basic_block(self.fn_value(), name);
2698        self.builder
2699            .build_conditional_branch(loop_while, start_loop_block, post_block)?;
2700        self.builder.position_at_end(post_block);
2701
2702        if self.threaded {
2703            self.jit_end_threading(
2704                thread_start.unwrap(),
2705                thread_end.unwrap(),
2706                test_done.unwrap(),
2707                next.unwrap(),
2708            )?;
2709        }
2710
2711        Ok(())
2712    }
2713
2714    fn jit_compile_broadcast_and_store(
2715        &mut self,
2716        name: &str,
2717        elmt: &TensorBlock,
2718        float_value: FloatValue<'ctx>,
2719        expr_index: IntValue<'ctx>,
2720        translation: &Translation,
2721    ) -> Result<()> {
2722        let int_type = self.int_type;
2723        let one = int_type.const_int(1, false);
2724        let zero = int_type.const_int(0, false);
2725        let pre_block = self.builder.get_insert_block().unwrap();
2726        match translation.source {
2727            TranslationFrom::Broadcast {
2728                broadcast_by: _,
2729                broadcast_len,
2730            } => {
2731                let bcast_start_index = zero;
2732                let bcast_end_index = int_type.const_int(broadcast_len.try_into().unwrap(), false);
2733
2734                // setup loop block
2735                let bcast_block = self.context.append_basic_block(self.fn_value(), name);
2736                self.builder.build_unconditional_branch(bcast_block)?;
2737                self.builder.position_at_end(bcast_block);
2738                let bcast_index = self.builder.build_phi(int_type, "broadcast_index")?;
2739                bcast_index.add_incoming(&[(&bcast_start_index, pre_block)]);
2740
2741                // store value
2742                let store_index = self.builder.build_int_add(
2743                    self.builder
2744                        .build_int_mul(expr_index, bcast_end_index, "store_index")?,
2745                    bcast_index.as_basic_value().into_int_value(),
2746                    "bcast_store_index",
2747                )?;
2748                self.jit_compile_store(name, elmt, store_index, float_value, translation)?;
2749
2750                // increment index
2751                let bcast_next_index = self.builder.build_int_add(
2752                    bcast_index.as_basic_value().into_int_value(),
2753                    one,
2754                    name,
2755                )?;
2756                bcast_index.add_incoming(&[(&bcast_next_index, bcast_block)]);
2757
2758                // loop condition
2759                let bcast_cond = self.builder.build_int_compare(
2760                    IntPredicate::ULT,
2761                    bcast_next_index,
2762                    bcast_end_index,
2763                    "broadcast_cond",
2764                )?;
2765                let post_bcast_block = self.context.append_basic_block(self.fn_value(), name);
2766                self.builder
2767                    .build_conditional_branch(bcast_cond, bcast_block, post_bcast_block)?;
2768                self.builder.position_at_end(post_bcast_block);
2769                Ok(())
2770            }
2771            TranslationFrom::ElementWise | TranslationFrom::DiagonalContraction { .. } => {
2772                self.jit_compile_store(name, elmt, expr_index, float_value, translation)?;
2773                Ok(())
2774            }
2775            _ => Err(anyhow!("Invalid translation")),
2776        }
2777    }
2778
2779    fn jit_compile_store(
2780        &mut self,
2781        name: &str,
2782        elmt: &TensorBlock,
2783        store_index: IntValue<'ctx>,
2784        float_value: FloatValue<'ctx>,
2785        translation: &Translation,
2786    ) -> Result<()> {
2787        let int_type = self.int_type;
2788        let res_index = match &translation.target {
2789            TranslationTo::Contiguous { start, end: _ } => {
2790                let start_const = int_type.const_int((*start).try_into().unwrap(), false);
2791                self.builder.build_int_add(start_const, store_index, name)?
2792            }
2793            TranslationTo::Sparse { indices: _ } => {
2794                // load store index from layout
2795                let translate_index = self
2796                    .layout
2797                    .get_translation_index(elmt.expr_layout(), elmt.layout())
2798                    .unwrap();
2799                let translate_store_index =
2800                    translate_index + translation.get_to_index_in_data_layout();
2801                let translate_store_index =
2802                    int_type.const_int(translate_store_index.try_into().unwrap(), false);
2803                let elmt_index_strided = store_index;
2804                let curr_index =
2805                    self.builder
2806                        .build_int_add(elmt_index_strided, translate_store_index, name)?;
2807                let ptr = Self::get_ptr_to_index(
2808                    &self.builder,
2809                    self.int_type,
2810                    self.get_param("indices"),
2811                    curr_index,
2812                    name,
2813                );
2814                self.build_load(self.int_type, ptr, name)?.into_int_value()
2815            }
2816        };
2817
2818        let resi_ptr = Self::get_ptr_to_index(
2819            &self.builder,
2820            self.real_type,
2821            &self.tensor_ptr(),
2822            res_index,
2823            name,
2824        );
2825        self.builder.build_store(resi_ptr, float_value)?;
2826        Ok(())
2827    }
2828
2829    fn jit_compile_expr(
2830        &mut self,
2831        name: &str,
2832        expr: &Ast,
2833        index: &[IntValue<'ctx>],
2834        elmt: &TensorBlock,
2835        expr_index: IntValue<'ctx>,
2836    ) -> Result<FloatValue<'ctx>> {
2837        let name = elmt.name().unwrap_or(name);
2838        match &expr.kind {
2839            AstKind::Binop(binop) => {
2840                let lhs =
2841                    self.jit_compile_expr(name, binop.left.as_ref(), index, elmt, expr_index)?;
2842                let rhs =
2843                    self.jit_compile_expr(name, binop.right.as_ref(), index, elmt, expr_index)?;
2844                match binop.op {
2845                    '*' => Ok(self.builder.build_float_mul(lhs, rhs, name)?),
2846                    '/' => Ok(self.builder.build_float_div(lhs, rhs, name)?),
2847                    '-' => Ok(self.builder.build_float_sub(lhs, rhs, name)?),
2848                    '+' => Ok(self.builder.build_float_add(lhs, rhs, name)?),
2849                    unknown => Err(anyhow!("unknown binop op '{}'", unknown)),
2850                }
2851            }
2852            AstKind::Monop(monop) => {
2853                let child =
2854                    self.jit_compile_expr(name, monop.child.as_ref(), index, elmt, expr_index)?;
2855                match monop.op {
2856                    '-' => Ok(self.builder.build_float_neg(child, name)?),
2857                    unknown => Err(anyhow!("unknown monop op '{}'", unknown)),
2858                }
2859            }
2860            AstKind::Call(call) => match self.get_function(call.fn_name) {
2861                Some(function) => {
2862                    let mut args: Vec<BasicMetadataValueEnum> = Vec::new();
2863                    for arg in call.args.iter() {
2864                        let arg_val =
2865                            self.jit_compile_expr(name, arg.as_ref(), index, elmt, expr_index)?;
2866                        args.push(BasicMetadataValueEnum::FloatValue(arg_val));
2867                    }
2868                    let ret_value = self
2869                        .builder
2870                        .build_call(function, args.as_slice(), name)?
2871                        .try_as_basic_value()
2872                        .unwrap_basic()
2873                        .into_float_value();
2874                    Ok(ret_value)
2875                }
2876                None => Err(anyhow!("unknown function call '{}'", call.fn_name)),
2877            },
2878            AstKind::CallArg(arg) => {
2879                self.jit_compile_expr(name, &arg.expression, index, elmt, expr_index)
2880            }
2881            AstKind::Number(value) => Ok(self.real_type.const_float(*value)),
2882            AstKind::Name(iname) => {
2883                if iname.name == "N" {
2884                    if iname.is_tangent {
2885                        return Ok(self.real_type.const_float(0.0));
2886                    }
2887                    let model_index = self
2888                        .build_load(self.int_type, *self.get_param("model_index"), "model_index")?
2889                        .into_int_value();
2890                    let n_value = self.builder.build_signed_int_to_float(
2891                        model_index,
2892                        self.real_type,
2893                        "n_as_real",
2894                    )?;
2895                    return Ok(n_value);
2896                }
2897                let ptr = *self.get_param(iname.name);
2898                let layout = self.layout.get_layout(iname.name).unwrap().clone();
2899                let iname_elmt_index = if layout.is_dense() {
2900                    // permute indices based on the index chars of this tensor
2901                    let mut no_transform = true;
2902                    let mut iname_index = Vec::new();
2903                    for (i, c) in iname.indices.iter().enumerate() {
2904                        // find the position index of this index char in the tensor's index chars,
2905                        // if it's not found then it must be a contraction index so is at the end
2906                        let pi = elmt
2907                            .indices()
2908                            .iter()
2909                            .position(|x| x == c)
2910                            .unwrap_or(elmt.indices().len());
2911                        // if we are indexing, add the start indice to index[pi]
2912                        if let Some(indice_ast) = iname.indice.as_ref() {
2913                            let Some(indice) = indice_ast.kind.as_indice() else {
2914                                return Err(anyhow!("invalid index expression '{}'", indice_ast));
2915                            };
2916                            let start_intval =
2917                                self.jit_compile_integer_expr(indice.first.as_ref(), name)?;
2918                            let index_pi = if indice.last.is_some() {
2919                                // range indexing: add loop index to start
2920                                let index_pi = if pi >= index.len() {
2921                                    self.context.i32_type().const_int(0, false)
2922                                } else {
2923                                    index[pi]
2924                                };
2925                                self.builder.build_int_add(index_pi, start_intval, name)?
2926                            } else {
2927                                // single-element indexing: constant lookup
2928                                start_intval
2929                            };
2930                            iname_index.push(index_pi);
2931                        } else {
2932                            let index_pi = if pi >= index.len() {
2933                                self.context.i32_type().const_int(0, false)
2934                            } else {
2935                                index[pi]
2936                            };
2937                            iname_index.push(index_pi);
2938                        }
2939                        no_transform = no_transform && pi == i;
2940                    }
2941                    // broadcasting, if the shape is 1 then the index is always 0
2942                    let iname_index: Vec<IntValue> = iname_index
2943                        .iter()
2944                        .enumerate()
2945                        .map(|(i, &idx)| {
2946                            if layout.shape()[i] == 1 {
2947                                self.context.i32_type().const_int(0, false)
2948                            } else {
2949                                idx
2950                            }
2951                        })
2952                        .collect();
2953                    // calculate the element index using iname_index and the shape of the tensor
2954                    // TODO: can we optimise this by using expr_index, and also including elmt_index?
2955                    if !iname_index.is_empty() {
2956                        // if iname_index is (a, b, c) and shape is s calculate iname_elmt_index = ( c * s[2] + b * s[2] * s[1] + a * s[2] * s[1] * s[0] )
2957                        let mut iname_elmt_index = *iname_index.last().unwrap();
2958                        let mut stride = 1u64;
2959                        for i in (0..iname_index.len() - 1).rev() {
2960                            let iname_i = iname_index[i];
2961                            let shapei: u64 = layout.shape()[i + 1].try_into().unwrap();
2962                            stride *= shapei;
2963                            let stride_intval = self.context.i32_type().const_int(stride, false);
2964                            let stride_mul_i =
2965                                self.builder.build_int_mul(stride_intval, iname_i, name)?;
2966                            iname_elmt_index =
2967                                self.builder
2968                                    .build_int_add(iname_elmt_index, stride_mul_i, name)?;
2969                        }
2970                        iname_elmt_index
2971                    } else {
2972                        // zero if we are not indexing, otherwise use the start value of indice
2973                        let zero = self.context.i32_type().const_int(0, false);
2974                        zero
2975                    }
2976                } else if layout.is_sparse() || layout.is_diagonal() {
2977                    let expr_layout = elmt.expr_layout();
2978
2979                    if expr_layout != &layout {
2980                        // get correct index from binary layout map, ie. indices[ binary_layout_index + expr_index ]
2981                        // if its a -1 then return a 0
2982                        // ie. expr_index = binary_layout[expr_index]
2983                        //.    if expr_index == -1 then return 0 as the value of the expression
2984                        //.    otherwise load the value at that index
2985                        // we are doing an if statement so I think we need to return early here
2986                        let permutation =
2987                            DataLayout::permutation(elmt, iname.indices.as_slice(), &layout);
2988                        if let Some(base_binary_layout_index) =
2989                            self.layout
2990                                .get_binary_layout_index(&layout, expr_layout, permutation)
2991                        {
2992                            let binary_layout_index = self.builder.build_int_add(
2993                                self.int_type
2994                                    .const_int(base_binary_layout_index.try_into().unwrap(), false),
2995                                expr_index,
2996                                name,
2997                            )?;
2998
2999                            let indices_ptr = Self::get_ptr_to_index(
3000                                &self.builder,
3001                                self.int_type,
3002                                self.get_param("indices"),
3003                                binary_layout_index,
3004                                name,
3005                            );
3006                            let mapped_index = self
3007                                .build_load(self.int_type, indices_ptr, name)?
3008                                .into_int_value();
3009                            let is_less_than_zero = self.builder.build_int_compare(
3010                                IntPredicate::SLT,
3011                                mapped_index,
3012                                self.int_type.const_int(0, true),
3013                                "sparse_index_check",
3014                            )?;
3015                            let is_less_than_zero_block =
3016                                self.context.append_basic_block(self.fn_value(), "lt_zero");
3017                            let not_less_than_zero_block = self
3018                                .context
3019                                .append_basic_block(self.fn_value(), "not_lt_zero");
3020                            let merge_block =
3021                                self.context.append_basic_block(self.fn_value(), "merge");
3022                            self.builder.build_conditional_branch(
3023                                is_less_than_zero,
3024                                is_less_than_zero_block,
3025                                not_less_than_zero_block,
3026                            )?;
3027
3028                            // if mapped index < 0 return 0
3029                            self.builder.position_at_end(is_less_than_zero_block);
3030                            let zero_value = self.real_type.const_float(0.);
3031                            self.builder.build_unconditional_branch(merge_block)?;
3032
3033                            // if mapped index >=0 load value at that index
3034                            self.builder.position_at_end(not_less_than_zero_block);
3035                            let value_ptr = Self::get_ptr_to_index(
3036                                &self.builder,
3037                                self.real_type,
3038                                &ptr,
3039                                mapped_index,
3040                                name,
3041                            );
3042                            let value = self
3043                                .build_load(self.real_type, value_ptr, name)?
3044                                .into_float_value();
3045                            self.builder.build_unconditional_branch(merge_block)?;
3046
3047                            // return value or 0 from if statement
3048                            self.builder.position_at_end(merge_block);
3049                            let if_return_value =
3050                                self.builder.build_phi(self.real_type, "sparse_value")?;
3051                            if_return_value.add_incoming(&[(&zero_value, is_less_than_zero_block)]);
3052                            if_return_value.add_incoming(&[(&value, not_less_than_zero_block)]);
3053
3054                            let phi_value = if_return_value.as_basic_value().into_float_value();
3055                            return Ok(phi_value);
3056                        } else {
3057                            expr_index
3058                        }
3059                    } else {
3060                        // we can just use the elmt_index since the layouts are the same
3061                        expr_index
3062                    }
3063                } else {
3064                    panic!("unexpected layout");
3065                };
3066                let value_ptr = Self::get_ptr_to_index(
3067                    &self.builder,
3068                    self.real_type,
3069                    &ptr,
3070                    iname_elmt_index,
3071                    name,
3072                );
3073                Ok(self
3074                    .build_load(self.real_type, value_ptr, name)?
3075                    .into_float_value())
3076            }
3077            AstKind::NamedGradient(name) => {
3078                let name_str = name.to_string();
3079                let ptr = self.get_param(name_str.as_str());
3080                Ok(self
3081                    .build_load(self.real_type, *ptr, name_str.as_str())?
3082                    .into_float_value())
3083            }
3084            AstKind::Index(_) => todo!(),
3085            AstKind::Slice(_) => todo!(),
3086            AstKind::Integer(_) => todo!(),
3087            _ => panic!("unexprected astkind"),
3088        }
3089    }
3090
3091    fn jit_compile_integer_expr(&mut self, expr: &Ast, name: &str) -> Result<IntValue<'ctx>> {
3092        match &expr.kind {
3093            AstKind::Integer(value) => Ok(self.int_type.const_int(*value as u64, true)),
3094            AstKind::Number(value) => {
3095                if value.fract() != 0.0 {
3096                    return Err(anyhow!(
3097                        "non-integer value '{}' in integer expression",
3098                        value
3099                    ));
3100                }
3101                Ok(self.int_type.const_int(*value as u64, true))
3102            }
3103            AstKind::Name(iname) => {
3104                if iname.name == "N" {
3105                    Ok(self
3106                        .build_load(self.int_type, *self.get_param("model_index"), name)?
3107                        .into_int_value())
3108                } else {
3109                    Err(anyhow!(
3110                        "unsupported name '{}' in integer expression",
3111                        iname.name
3112                    ))
3113                }
3114            }
3115            AstKind::Monop(monop) => {
3116                let child = self.jit_compile_integer_expr(monop.child.as_ref(), name)?;
3117                match monop.op {
3118                    '+' => Ok(child),
3119                    '-' => self.builder.build_int_neg(child, name).map_err(Into::into),
3120                    _ => Err(anyhow!("unknown integer unary op '{}'", monop.op)),
3121                }
3122            }
3123            AstKind::Binop(binop) => {
3124                let lhs = self.jit_compile_integer_expr(binop.left.as_ref(), name)?;
3125                let rhs = self.jit_compile_integer_expr(binop.right.as_ref(), name)?;
3126                match binop.op {
3127                    '+' => self
3128                        .builder
3129                        .build_int_add(lhs, rhs, name)
3130                        .map_err(Into::into),
3131                    '-' => self
3132                        .builder
3133                        .build_int_sub(lhs, rhs, name)
3134                        .map_err(Into::into),
3135                    '*' => self
3136                        .builder
3137                        .build_int_mul(lhs, rhs, name)
3138                        .map_err(Into::into),
3139                    '/' => self
3140                        .builder
3141                        .build_int_signed_div(lhs, rhs, name)
3142                        .map_err(Into::into),
3143                    '%' => self
3144                        .builder
3145                        .build_int_signed_rem(lhs, rhs, name)
3146                        .map_err(Into::into),
3147                    _ => Err(anyhow!("unknown integer binary op '{}'", binop.op)),
3148                }
3149            }
3150            _ => Err(anyhow!("unsupported integer expression '{}'", expr)),
3151        }
3152    }
3153
3154    fn clear(&mut self) {
3155        self.variables.clear();
3156        //self.functions.clear();
3157        self.fn_value_opt = None;
3158        self.tensor_ptr_opt = None;
3159    }
3160
3161    fn build_dep_call(
3162        &mut self,
3163        dep_fn: FunctionValue<'ctx>,
3164        call_name: &str,
3165        barrier_start: u64,
3166        total_barriers: u64,
3167    ) -> Result<()> {
3168        let t = self
3169            .build_load(self.real_type, *self.get_param("t"), "t")?
3170            .into_float_value();
3171        let u = *self.get_param("u");
3172        let data = *self.get_param("data");
3173        let thread_id = self
3174            .build_load(self.int_type, *self.get_param("thread_id"), "thread_id")?
3175            .into_int_value();
3176        let thread_dim = self
3177            .build_load(self.int_type, *self.get_param("thread_dim"), "thread_dim")?
3178            .into_int_value();
3179        let barrier_start = self.int_type.const_int(barrier_start, false);
3180        let total_barriers = self.int_type.const_int(total_barriers, false);
3181        self.builder.build_call(
3182            dep_fn,
3183            &[
3184                t.into(),
3185                u.into(),
3186                data.into(),
3187                thread_id.into(),
3188                thread_dim.into(),
3189                barrier_start.into(),
3190                total_barriers.into(),
3191            ],
3192            call_name,
3193        )?;
3194        Ok(())
3195    }
3196
3197    fn ensure_time_dep_fn<'m>(
3198        &mut self,
3199        model: &'m DiscreteModel,
3200        code: Option<&str>,
3201    ) -> Result<FunctionValue<'ctx>> {
3202        if let Some(function) = self.module.get_function("calc_time_dep") {
3203            return Ok(function);
3204        }
3205        self.compile_dep_defns(model, "calc_time_dep", model.time_dep_defns(), code)
3206    }
3207
3208    fn ensure_state_dep_fn<'m>(
3209        &mut self,
3210        model: &'m DiscreteModel,
3211        code: Option<&str>,
3212    ) -> Result<FunctionValue<'ctx>> {
3213        if let Some(function) = self.module.get_function("calc_state_dep") {
3214            return Ok(function);
3215        }
3216        self.compile_dep_defns(model, "calc_state_dep", model.state_dep_defns(), code)
3217    }
3218
3219    fn ensure_state_dep_post_f_fn<'m>(
3220        &mut self,
3221        model: &'m DiscreteModel,
3222        code: Option<&str>,
3223    ) -> Result<FunctionValue<'ctx>> {
3224        if let Some(function) = self.module.get_function("calc_state_dep_post_f") {
3225            return Ok(function);
3226        }
3227        self.compile_dep_defns(
3228            model,
3229            "calc_state_dep_post_f",
3230            model.state_dep_post_f_defns(),
3231            code,
3232        )
3233    }
3234
3235    fn function_arg_alloca(&mut self, name: &str, arg: BasicValueEnum<'ctx>) -> PointerValue<'ctx> {
3236        match arg {
3237            BasicValueEnum::PointerValue(v) => v,
3238            BasicValueEnum::FloatValue(v) => {
3239                let alloca = self
3240                    .create_entry_block_builder()
3241                    .build_alloca(arg.get_type(), name)
3242                    .unwrap();
3243                self.builder.build_store(alloca, v).unwrap();
3244                alloca
3245            }
3246            BasicValueEnum::IntValue(v) => {
3247                let alloca = self
3248                    .create_entry_block_builder()
3249                    .build_alloca(arg.get_type(), name)
3250                    .unwrap();
3251                self.builder.build_store(alloca, v).unwrap();
3252                alloca
3253            }
3254            _ => unreachable!(),
3255        }
3256    }
3257
3258    pub fn compile_set_u0<'m>(
3259        &mut self,
3260        model: &'m DiscreteModel,
3261        code: Option<&str>,
3262    ) -> Result<FunctionValue<'ctx>> {
3263        self.clear();
3264        let fn_arg_names = &["u0", "data", "thread_id", "thread_dim"];
3265        let function = self.add_function(
3266            "set_u0",
3267            fn_arg_names,
3268            &[
3269                self.real_ptr_type.into(),
3270                self.real_ptr_type.into(),
3271                self.int_type.into(),
3272                self.int_type.into(),
3273            ],
3274            None,
3275            false,
3276        );
3277
3278        // add noalias
3279        let alias_id = Attribute::get_named_enum_kind_id("noalias");
3280        let noalign = self.context.create_enum_attribute(alias_id, 0);
3281        for i in &[0, 1] {
3282            function.add_attribute(AttributeLoc::Param(*i), noalign);
3283        }
3284
3285        let _basic_block = self.start_function(function, code);
3286
3287        for (i, arg) in function.get_param_iter().enumerate() {
3288            let name = fn_arg_names[i];
3289            let alloca = self.function_arg_alloca(name, arg);
3290            self.insert_param(name, alloca);
3291        }
3292
3293        self.insert_data(model);
3294        self.insert_indices();
3295
3296        let mut nbarriers = 0;
3297        let total_barriers = (model.input_dep_defns().len() + 1) as u64;
3298        let total_barriers_val = self.int_type.const_int(total_barriers, false);
3299        #[allow(clippy::explicit_counter_loop)]
3300        for a in model.input_dep_defns() {
3301            self.jit_compile_tensor(a, Some(*self.get_var(a)), code)?;
3302            let barrier_num = self.int_type.const_int(nbarriers + 1, false);
3303            self.jit_compile_call_barrier(barrier_num, total_barriers_val);
3304            nbarriers += 1;
3305        }
3306
3307        self.jit_compile_tensor(model.state(), Some(*self.get_param("u0")), code)?;
3308        let barrier_num = self.int_type.const_int(nbarriers + 1, false);
3309        self.jit_compile_call_barrier(barrier_num, total_barriers_val);
3310
3311        self.builder.build_return(None)?;
3312
3313        if function.verify(true) {
3314            Ok(function)
3315        } else {
3316            function.print_to_stderr();
3317            unsafe {
3318                function.delete();
3319            }
3320            Err(anyhow!("Invalid generated function."))
3321        }
3322    }
3323
3324    pub fn compile_calc_out<'m>(
3325        &mut self,
3326        model: &'m DiscreteModel,
3327        include_constants: bool,
3328        code: Option<&str>,
3329    ) -> Result<FunctionValue<'ctx>> {
3330        let time_dep_fn = self.ensure_time_dep_fn(model, code)?;
3331        let state_dep_fn = self.ensure_state_dep_fn(model, code)?;
3332        let state_dep_post_f_fn = self.ensure_state_dep_post_f_fn(model, code)?;
3333        self.clear();
3334        let fn_arg_names = &["t", "u", "data", "out", "thread_id", "thread_dim"];
3335        let function_name = if include_constants {
3336            "calc_out_full"
3337        } else {
3338            "calc_out"
3339        };
3340        let function = self.add_function(
3341            function_name,
3342            fn_arg_names,
3343            &[
3344                self.real_type.into(),
3345                self.real_ptr_type.into(),
3346                self.real_ptr_type.into(),
3347                self.real_ptr_type.into(),
3348                self.int_type.into(),
3349                self.int_type.into(),
3350            ],
3351            None,
3352            false,
3353        );
3354
3355        // add noalias
3356        let alias_id = Attribute::get_named_enum_kind_id("noalias");
3357        let noalign = self.context.create_enum_attribute(alias_id, 0);
3358        for i in &[1, 2, 3] {
3359            function.add_attribute(AttributeLoc::Param(*i), noalign);
3360        }
3361
3362        let _basic_block = self.start_function(function, code);
3363
3364        for (i, arg) in function.get_param_iter().enumerate() {
3365            let name = fn_arg_names[i];
3366            let alloca = self.function_arg_alloca(name, arg);
3367            self.insert_param(name, alloca);
3368        }
3369
3370        self.insert_state(model.state());
3371        self.insert_data(model);
3372        self.insert_indices();
3373
3374        // print thread_id and thread_dim
3375        //let thread_id = function.get_nth_param(3).unwrap();
3376        //let thread_dim = function.get_nth_param(4).unwrap();
3377        //self.compile_print_value("thread_id", PrintValue::Int(thread_id.into_int_value()))?;
3378        //self.compile_print_value("thread_dim", PrintValue::Int(thread_dim.into_int_value()))?;
3379        if let Some(out) = model.out() {
3380            let mut nbarriers = 0;
3381            let mut total_barriers = (model.time_dep_defns().len()
3382                + model.state_dep_defns().len()
3383                + model.state_dep_post_f_defns().len()
3384                + 1) as u64;
3385            if include_constants {
3386                total_barriers += model.input_dep_defns().len() as u64;
3387            }
3388            let total_barriers_val = self.int_type.const_int(total_barriers, false);
3389            if include_constants {
3390                // calculate time independant definitions
3391                for tensor in model.input_dep_defns() {
3392                    self.jit_compile_tensor(tensor, Some(*self.get_var(tensor)), code)?;
3393                    let barrier_num = self.int_type.const_int(nbarriers + 1, false);
3394                    self.jit_compile_call_barrier(barrier_num, total_barriers_val);
3395                    nbarriers += 1;
3396                }
3397            }
3398
3399            // calculate time dependant definitions
3400            if !model.time_dep_defns().is_empty() {
3401                self.build_dep_call(time_dep_fn, "time_dep", nbarriers, total_barriers)?;
3402                nbarriers += model.time_dep_defns().len() as u64;
3403            }
3404
3405            // calculate state dependant definitions
3406            if !model.state_dep_defns().is_empty() {
3407                self.build_dep_call(state_dep_fn, "state_dep", nbarriers, total_barriers)?;
3408                nbarriers += model.state_dep_defns().len() as u64;
3409            }
3410            if !model.state_dep_post_f_defns().is_empty() {
3411                self.build_dep_call(state_dep_post_f_fn, "state_dep", nbarriers, total_barriers)?;
3412                nbarriers += model.state_dep_post_f_defns().len() as u64;
3413            }
3414
3415            self.jit_compile_tensor(out, Some(*self.get_var(model.out().unwrap())), code)?;
3416            let barrier_num = self.int_type.const_int(nbarriers + 1, false);
3417            self.jit_compile_call_barrier(barrier_num, total_barriers_val);
3418        }
3419        self.builder.build_return(None)?;
3420
3421        if function.verify(true) {
3422            Ok(function)
3423        } else {
3424            function.print_to_stderr();
3425            unsafe {
3426                function.delete();
3427            }
3428            Err(anyhow!("Invalid generated function."))
3429        }
3430    }
3431
3432    fn compile_dep_defns<'m>(
3433        &mut self,
3434        model: &'m DiscreteModel,
3435        fn_name: &str,
3436        tensors: &[Tensor<'m>],
3437        code: Option<&str>,
3438    ) -> Result<FunctionValue<'ctx>> {
3439        self.clear();
3440        let fn_arg_names = &[
3441            "t",
3442            "u",
3443            "data",
3444            "thread_id",
3445            "thread_dim",
3446            "barrier_start",
3447            "total_barriers",
3448        ];
3449        let function = self.add_function(
3450            fn_name,
3451            fn_arg_names,
3452            &[
3453                self.real_type.into(),
3454                self.real_ptr_type.into(),
3455                self.real_ptr_type.into(),
3456                self.int_type.into(),
3457                self.int_type.into(),
3458                self.int_type.into(),
3459                self.int_type.into(),
3460            ],
3461            None,
3462            false,
3463        );
3464
3465        // add noalias
3466        let alias_id = Attribute::get_named_enum_kind_id("noalias");
3467        let noalign = self.context.create_enum_attribute(alias_id, 0);
3468        for i in &[1, 2] {
3469            function.add_attribute(AttributeLoc::Param(*i), noalign);
3470        }
3471
3472        let _basic_block = self.start_function(function, code);
3473
3474        for (i, arg) in function.get_param_iter().enumerate() {
3475            let name = fn_arg_names[i];
3476            let alloca = self.function_arg_alloca(name, arg);
3477            self.insert_param(name, alloca);
3478        }
3479
3480        self.insert_state(model.state());
3481        self.insert_data(model);
3482        self.insert_indices();
3483
3484        let barrier_start = self
3485            .build_load(
3486                self.int_type,
3487                *self.get_param("barrier_start"),
3488                "barrier_start",
3489            )?
3490            .into_int_value();
3491        let total_barriers = self
3492            .build_load(
3493                self.int_type,
3494                *self.get_param("total_barriers"),
3495                "total_barriers",
3496            )?
3497            .into_int_value();
3498        let one = self.int_type.const_int(1, false);
3499
3500        for (index, tensor) in tensors.iter().enumerate() {
3501            self.jit_compile_tensor(tensor, Some(*self.get_var(tensor)), code)?;
3502            let index = self.int_type.const_int(index as u64, false);
3503            let barrier_num_base =
3504                self.builder
3505                    .build_int_add(barrier_start, index, "barrier_num_base")?;
3506            let barrier_num = self
3507                .builder
3508                .build_int_add(barrier_num_base, one, "barrier_num")?;
3509            self.jit_compile_call_barrier(barrier_num, total_barriers);
3510        }
3511
3512        self.builder.build_return(None)?;
3513
3514        if function.verify(true) {
3515            Ok(function)
3516        } else {
3517            function.print_to_stderr();
3518            unsafe {
3519                function.delete();
3520            }
3521            Err(anyhow!("Invalid generated function."))
3522        }
3523    }
3524
3525    pub fn compile_calc_stop<'m>(
3526        &mut self,
3527        model: &'m DiscreteModel,
3528        include_constants: bool,
3529        code: Option<&str>,
3530    ) -> Result<FunctionValue<'ctx>> {
3531        let time_dep_fn = self.ensure_time_dep_fn(model, code)?;
3532        let state_dep_fn = self.ensure_state_dep_fn(model, code)?;
3533        let state_dep_post_f_fn = self.ensure_state_dep_post_f_fn(model, code)?;
3534        self.clear();
3535        let fn_arg_names = &["t", "u", "data", "root", "thread_id", "thread_dim"];
3536        let function_name = if include_constants {
3537            "calc_stop_full"
3538        } else {
3539            "calc_stop"
3540        };
3541        let function = self.add_function(
3542            function_name,
3543            fn_arg_names,
3544            &[
3545                self.real_type.into(),
3546                self.real_ptr_type.into(),
3547                self.real_ptr_type.into(),
3548                self.real_ptr_type.into(),
3549                self.int_type.into(),
3550                self.int_type.into(),
3551            ],
3552            None,
3553            false,
3554        );
3555
3556        // add noalias
3557        let alias_id = Attribute::get_named_enum_kind_id("noalias");
3558        let noalign = self.context.create_enum_attribute(alias_id, 0);
3559        for i in &[1, 2, 3] {
3560            function.add_attribute(AttributeLoc::Param(*i), noalign);
3561        }
3562
3563        let _basic_block = self.start_function(function, code);
3564
3565        for (i, arg) in function.get_param_iter().enumerate() {
3566            let name = fn_arg_names[i];
3567            let alloca = self.function_arg_alloca(name, arg);
3568            self.insert_param(name, alloca);
3569        }
3570
3571        self.insert_state(model.state());
3572        self.insert_data(model);
3573        self.insert_indices();
3574
3575        if let Some(stop) = model.stop() {
3576            // calculate time dependant definitions
3577            let mut nbarriers = 0;
3578            let mut total_barriers = (model.time_dep_defns().len()
3579                + model.state_dep_defns().len()
3580                + model.state_dep_post_f_defns().len()
3581                + 1) as u64;
3582            if include_constants {
3583                total_barriers += model.input_dep_defns().len() as u64;
3584            }
3585            let total_barriers_val = self.int_type.const_int(total_barriers, false);
3586            if include_constants {
3587                // calculate time independent definitions
3588                for tensor in model.input_dep_defns() {
3589                    self.jit_compile_tensor(tensor, Some(*self.get_var(tensor)), code)?;
3590                    let barrier_num = self.int_type.const_int(nbarriers + 1, false);
3591                    self.jit_compile_call_barrier(barrier_num, total_barriers_val);
3592                    nbarriers += 1;
3593                }
3594            }
3595            if !model.time_dep_defns().is_empty() {
3596                self.build_dep_call(time_dep_fn, "time_dep", nbarriers, total_barriers)?;
3597                nbarriers += model.time_dep_defns().len() as u64;
3598            }
3599
3600            // calculate state dependant definitions
3601            if !model.state_dep_defns().is_empty() {
3602                self.build_dep_call(state_dep_fn, "state_dep", nbarriers, total_barriers)?;
3603                nbarriers += model.state_dep_defns().len() as u64;
3604            }
3605            if !model.state_dep_post_f_defns().is_empty() {
3606                self.build_dep_call(state_dep_post_f_fn, "state_dep", nbarriers, total_barriers)?;
3607                nbarriers += model.state_dep_post_f_defns().len() as u64;
3608            }
3609
3610            let res_ptr = self.get_param("root");
3611            self.jit_compile_tensor(stop, Some(*res_ptr), code)?;
3612            let barrier_num = self.int_type.const_int(nbarriers + 1, false);
3613            self.jit_compile_call_barrier(barrier_num, total_barriers_val);
3614        }
3615        self.builder.build_return(None)?;
3616
3617        if function.verify(true) {
3618            Ok(function)
3619        } else {
3620            function.print_to_stderr();
3621            unsafe {
3622                function.delete();
3623            }
3624            Err(anyhow!("Invalid generated function."))
3625        }
3626    }
3627
3628    pub fn compile_reset<'m>(
3629        &mut self,
3630        model: &'m DiscreteModel,
3631        include_constants: bool,
3632        code: Option<&str>,
3633    ) -> Result<FunctionValue<'ctx>> {
3634        let time_dep_fn = self.ensure_time_dep_fn(model, code)?;
3635        let state_dep_fn = self.ensure_state_dep_fn(model, code)?;
3636        let state_dep_post_f_fn = self.ensure_state_dep_post_f_fn(model, code)?;
3637        self.clear();
3638        let fn_arg_names = &["t", "u", "data", "reset", "thread_id", "thread_dim"];
3639        let function_name = if include_constants {
3640            "reset_full"
3641        } else {
3642            "reset"
3643        };
3644        let function = self.add_function(
3645            function_name,
3646            fn_arg_names,
3647            &[
3648                self.real_type.into(),
3649                self.real_ptr_type.into(),
3650                self.real_ptr_type.into(),
3651                self.real_ptr_type.into(),
3652                self.int_type.into(),
3653                self.int_type.into(),
3654            ],
3655            None,
3656            false,
3657        );
3658
3659        let alias_id = Attribute::get_named_enum_kind_id("noalias");
3660        let noalign = self.context.create_enum_attribute(alias_id, 0);
3661        for i in &[1, 2, 3] {
3662            function.add_attribute(AttributeLoc::Param(*i), noalign);
3663        }
3664
3665        let _basic_block = self.start_function(function, code);
3666
3667        for (i, arg) in function.get_param_iter().enumerate() {
3668            let name = fn_arg_names[i];
3669            let alloca = self.function_arg_alloca(name, arg);
3670            self.insert_param(name, alloca);
3671        }
3672
3673        self.insert_state(model.state());
3674        self.insert_data(model);
3675        self.insert_indices();
3676
3677        if let Some(reset) = model.reset() {
3678            let mut nbarriers = 0;
3679            let mut total_barriers = (model.time_dep_defns().len()
3680                + model.state_dep_defns().len()
3681                + model.state_dep_post_f_defns().len()
3682                + 1) as u64;
3683            if include_constants {
3684                total_barriers += model.input_dep_defns().len() as u64;
3685            }
3686            let total_barriers_val = self.int_type.const_int(total_barriers, false);
3687            if include_constants {
3688                // calculate time independent definitions
3689                for tensor in model.input_dep_defns() {
3690                    self.jit_compile_tensor(tensor, Some(*self.get_var(tensor)), code)?;
3691                    let barrier_num = self.int_type.const_int(nbarriers + 1, false);
3692                    self.jit_compile_call_barrier(barrier_num, total_barriers_val);
3693                    nbarriers += 1;
3694                }
3695            }
3696            if !model.time_dep_defns().is_empty() {
3697                self.build_dep_call(time_dep_fn, "time_dep", nbarriers, total_barriers)?;
3698                nbarriers += model.time_dep_defns().len() as u64;
3699            }
3700
3701            if !model.state_dep_defns().is_empty() {
3702                self.build_dep_call(state_dep_fn, "state_dep", nbarriers, total_barriers)?;
3703                nbarriers += model.state_dep_defns().len() as u64;
3704            }
3705            if !model.state_dep_post_f_defns().is_empty() {
3706                self.build_dep_call(state_dep_post_f_fn, "state_dep", nbarriers, total_barriers)?;
3707                nbarriers += model.state_dep_post_f_defns().len() as u64;
3708            }
3709
3710            let res_ptr = self.get_param("reset");
3711            self.jit_compile_tensor(reset, Some(*res_ptr), code)?;
3712            let barrier_num = self.int_type.const_int(nbarriers + 1, false);
3713            self.jit_compile_call_barrier(barrier_num, total_barriers_val);
3714        }
3715        self.builder.build_return(None)?;
3716
3717        if function.verify(true) {
3718            Ok(function)
3719        } else {
3720            function.print_to_stderr();
3721            unsafe {
3722                function.delete();
3723            }
3724            Err(anyhow!("Invalid generated function."))
3725        }
3726    }
3727
3728    pub fn compile_rhs<'m>(
3729        &mut self,
3730        model: &'m DiscreteModel,
3731        include_constants: bool,
3732        code: Option<&str>,
3733    ) -> Result<FunctionValue<'ctx>> {
3734        let time_dep_fn = self.ensure_time_dep_fn(model, code)?;
3735        let state_dep_fn = self.ensure_state_dep_fn(model, code)?;
3736        self.clear();
3737        let fn_arg_names = &["t", "u", "data", "rr", "thread_id", "thread_dim"];
3738        let function_name = if include_constants { "rhs_full" } else { "rhs" };
3739        let function = self.add_function(
3740            function_name,
3741            fn_arg_names,
3742            &[
3743                self.real_type.into(),
3744                self.real_ptr_type.into(),
3745                self.real_ptr_type.into(),
3746                self.real_ptr_type.into(),
3747                self.int_type.into(),
3748                self.int_type.into(),
3749            ],
3750            None,
3751            false,
3752        );
3753
3754        // add noalias
3755        let alias_id = Attribute::get_named_enum_kind_id("noalias");
3756        let noalign = self.context.create_enum_attribute(alias_id, 0);
3757        for i in &[1, 2, 3] {
3758            function.add_attribute(AttributeLoc::Param(*i), noalign);
3759        }
3760
3761        let _basic_block = self.start_function(function, code);
3762
3763        for (i, arg) in function.get_param_iter().enumerate() {
3764            let name = fn_arg_names[i];
3765            let alloca = self.function_arg_alloca(name, arg);
3766            self.insert_param(name, alloca);
3767        }
3768
3769        self.insert_state(model.state());
3770        self.insert_data(model);
3771        self.insert_indices();
3772
3773        let mut nbarriers = 0;
3774        let mut total_barriers =
3775            (model.time_dep_defns().len() + model.state_dep_defns().len() + 1) as u64;
3776        if include_constants {
3777            total_barriers += model.input_dep_defns().len() as u64;
3778            // calculate constant definitions
3779            for tensor in model.input_dep_defns() {
3780                self.jit_compile_tensor(tensor, Some(*self.get_var(tensor)), code)?;
3781                let barrier_num = self.int_type.const_int(nbarriers + 1, false);
3782                let total_barriers_val = self.int_type.const_int(total_barriers, false);
3783                self.jit_compile_call_barrier(barrier_num, total_barriers_val);
3784                nbarriers += 1;
3785            }
3786        }
3787
3788        // calculate time dependant definitions
3789        if !model.time_dep_defns().is_empty() {
3790            self.build_dep_call(time_dep_fn, "time_dep", nbarriers, total_barriers)?;
3791            nbarriers += model.time_dep_defns().len() as u64;
3792        }
3793
3794        if !model.state_dep_defns().is_empty() {
3795            self.build_dep_call(state_dep_fn, "state_dep", nbarriers, total_barriers)?;
3796            nbarriers += model.state_dep_defns().len() as u64;
3797        }
3798
3799        // F
3800        let res_ptr = self.get_param("rr");
3801        self.jit_compile_tensor(model.rhs(), Some(*res_ptr), code)?;
3802        let total_barriers_val = self.int_type.const_int(total_barriers, false);
3803        let barrier_num = self.int_type.const_int(nbarriers + 1, false);
3804        self.jit_compile_call_barrier(barrier_num, total_barriers_val);
3805
3806        self.builder.build_return(None)?;
3807
3808        if function.verify(true) {
3809            Ok(function)
3810        } else {
3811            function.print_to_stderr();
3812            unsafe {
3813                function.delete();
3814            }
3815            Err(anyhow!("Invalid generated function."))
3816        }
3817    }
3818
3819    pub fn compile_mass<'m>(
3820        &mut self,
3821        model: &'m DiscreteModel,
3822        code: Option<&str>,
3823    ) -> Result<FunctionValue<'ctx>> {
3824        self.clear();
3825        let fn_arg_names = &["t", "dudt", "data", "rr", "thread_id", "thread_dim"];
3826        let function = self.add_function(
3827            "mass",
3828            fn_arg_names,
3829            &[
3830                self.real_type.into(),
3831                self.real_ptr_type.into(),
3832                self.real_ptr_type.into(),
3833                self.real_ptr_type.into(),
3834                self.int_type.into(),
3835                self.int_type.into(),
3836            ],
3837            None,
3838            false,
3839        );
3840
3841        // add noalias
3842        let alias_id = Attribute::get_named_enum_kind_id("noalias");
3843        let noalign = self.context.create_enum_attribute(alias_id, 0);
3844        for i in &[1, 2, 3] {
3845            function.add_attribute(AttributeLoc::Param(*i), noalign);
3846        }
3847
3848        let _basic_block = self.start_function(function, code);
3849
3850        for (i, arg) in function.get_param_iter().enumerate() {
3851            let name = fn_arg_names[i];
3852            let alloca = self.function_arg_alloca(name, arg);
3853            self.insert_param(name, alloca);
3854        }
3855
3856        // only put code in this function if we have a state_dot and lhs
3857        if model.state_dot().is_some() && model.lhs().is_some() {
3858            let state_dot = model.state_dot().unwrap();
3859            let lhs = model.lhs().unwrap();
3860
3861            self.insert_dot_state(state_dot);
3862            self.insert_data(model);
3863            self.insert_indices();
3864
3865            // calculate time dependant definitions
3866            let mut nbarriers = 0;
3867            let total_barriers =
3868                (model.time_dep_defns().len() + model.dstate_dep_defns().len() + 1) as u64;
3869            let total_barriers_val = self.int_type.const_int(total_barriers, false);
3870            for tensor in model.time_dep_defns() {
3871                self.jit_compile_tensor(tensor, Some(*self.get_var(tensor)), code)?;
3872                let barrier_num = self.int_type.const_int(nbarriers + 1, false);
3873                self.jit_compile_call_barrier(barrier_num, total_barriers_val);
3874                nbarriers += 1;
3875            }
3876
3877            for a in model.dstate_dep_defns() {
3878                self.jit_compile_tensor(a, Some(*self.get_var(a)), code)?;
3879                let barrier_num = self.int_type.const_int(nbarriers + 1, false);
3880                self.jit_compile_call_barrier(barrier_num, total_barriers_val);
3881                nbarriers += 1;
3882            }
3883
3884            // mass
3885            let res_ptr = self.get_param("rr");
3886            self.jit_compile_tensor(lhs, Some(*res_ptr), code)?;
3887            let barrier_num = self.int_type.const_int(nbarriers + 1, false);
3888            self.jit_compile_call_barrier(barrier_num, total_barriers_val);
3889        }
3890
3891        self.builder.build_return(None)?;
3892
3893        if function.verify(true) {
3894            Ok(function)
3895        } else {
3896            function.print_to_stderr();
3897            unsafe {
3898                function.delete();
3899            }
3900            Err(anyhow!("Invalid generated function."))
3901        }
3902    }
3903
3904    pub fn compile_gradient(
3905        &mut self,
3906        original_function: FunctionValue<'ctx>,
3907        args_type: &[CompileGradientArgType],
3908        mode: CompileMode,
3909        fn_name: &str,
3910    ) -> Result<FunctionValue<'ctx>> {
3911        self.clear();
3912
3913        // construct the gradient function
3914        let mut fn_type: Vec<BasicMetadataTypeEnum> = Vec::new();
3915
3916        let orig_fn_type_ptr = Self::fn_pointer_type(self.context, original_function.get_type());
3917
3918        let mut enzyme_fn_type: Vec<BasicMetadataTypeEnum> = vec![orig_fn_type_ptr.into()];
3919        let mut start_param_index: Vec<u32> = Vec::new();
3920        let mut ptr_arg_indices: Vec<u32> = Vec::new();
3921        for (i, arg) in original_function.get_param_iter().enumerate() {
3922            let start_index = u32::try_from(fn_type.len()).unwrap();
3923            start_param_index.push(start_index);
3924            let arg_type = arg.get_type();
3925            fn_type.push(arg_type.into());
3926
3927            // constant args with type T in the original funciton have 2 args of type [int, T]
3928            enzyme_fn_type.push(self.int_type.into());
3929            enzyme_fn_type.push(arg.get_type().into());
3930
3931            if arg_type.is_pointer_type() {
3932                ptr_arg_indices.push(start_index);
3933            }
3934
3935            match args_type[i] {
3936                CompileGradientArgType::Dup | CompileGradientArgType::DupNoNeed => {
3937                    fn_type.push(arg.get_type().into());
3938                    enzyme_fn_type.push(arg.get_type().into());
3939                    if arg_type.is_pointer_type() {
3940                        ptr_arg_indices.push(start_index + 1);
3941                    }
3942                }
3943                CompileGradientArgType::Const => {}
3944            }
3945        }
3946        let fn_arg_names = fn_type
3947            .iter()
3948            .enumerate()
3949            .map(|(i, _)| format!("arg{}", i))
3950            .collect::<Vec<String>>();
3951        let fn_arg_names_ref = fn_arg_names
3952            .iter()
3953            .map(|s| s.as_str())
3954            .collect::<Vec<&str>>();
3955        let function = self.add_function(fn_name, &fn_arg_names_ref, &fn_type, None, false);
3956
3957        // add noalias
3958        let alias_id = Attribute::get_named_enum_kind_id("noalias");
3959        let noalign = self.context.create_enum_attribute(alias_id, 0);
3960        for i in ptr_arg_indices {
3961            function.add_attribute(AttributeLoc::Param(i), noalign);
3962        }
3963
3964        let _basic_block = self.start_function(function, None);
3965
3966        let mut enzyme_fn_args: Vec<BasicMetadataValueEnum> = Vec::new();
3967        let mut input_activity = Vec::new();
3968        let mut arg_trees = Vec::new();
3969        for (i, arg) in original_function.get_param_iter().enumerate() {
3970            let param_index = start_param_index[i];
3971            let fn_arg = function.get_nth_param(param_index).unwrap();
3972
3973            // we'll probably only get double or pointers to doubles, so let assume this for now
3974            // todo: perhaps refactor this into a recursive function, might be overkill
3975            let concrete_type = match arg.get_type() {
3976                BasicTypeEnum::PointerType(_t) => CConcreteType_DT_Pointer,
3977                BasicTypeEnum::FloatType(_t) => match self.diffsl_real_type {
3978                    RealType::F32 => CConcreteType_DT_Float,
3979                    RealType::F64 => CConcreteType_DT_Double,
3980                },
3981                BasicTypeEnum::IntType(_) => CConcreteType_DT_Integer,
3982                _ => panic!("unsupported type"),
3983            };
3984            let new_tree = unsafe {
3985                EnzymeNewTypeTreeCT(
3986                    concrete_type,
3987                    self.context.as_ctx_ref() as *mut LLVMOpaqueContext,
3988                )
3989            };
3990            unsafe { EnzymeTypeTreeOnlyEq(new_tree, -1) };
3991
3992            // pointer to real type
3993            if concrete_type == CConcreteType_DT_Pointer {
3994                let inner_concrete_type = match self.diffsl_real_type {
3995                    RealType::F32 => CConcreteType_DT_Float,
3996                    RealType::F64 => CConcreteType_DT_Double,
3997                };
3998                let inner_new_tree = unsafe {
3999                    EnzymeNewTypeTreeCT(
4000                        inner_concrete_type,
4001                        self.context.as_ctx_ref() as *mut LLVMOpaqueContext,
4002                    )
4003                };
4004                unsafe { EnzymeTypeTreeOnlyEq(inner_new_tree, -1) };
4005                unsafe { EnzymeTypeTreeOnlyEq(inner_new_tree, -1) };
4006                unsafe { EnzymeMergeTypeTree(new_tree, inner_new_tree) };
4007            }
4008
4009            arg_trees.push(new_tree);
4010            match args_type[i] {
4011                CompileGradientArgType::Dup => {
4012                    // pass in the arg value
4013                    enzyme_fn_args.push(fn_arg.into());
4014
4015                    // pass in the darg value
4016                    let fn_darg = function.get_nth_param(param_index + 1).unwrap();
4017                    enzyme_fn_args.push(fn_darg.into());
4018
4019                    input_activity.push(CDIFFE_TYPE_DFT_DUP_ARG);
4020                }
4021                CompileGradientArgType::DupNoNeed => {
4022                    // pass in the arg value
4023                    enzyme_fn_args.push(fn_arg.into());
4024
4025                    // pass in the darg value
4026                    let fn_darg = function.get_nth_param(param_index + 1).unwrap();
4027                    enzyme_fn_args.push(fn_darg.into());
4028
4029                    input_activity.push(CDIFFE_TYPE_DFT_DUP_NONEED);
4030                }
4031                CompileGradientArgType::Const => {
4032                    // pass in the arg value
4033                    enzyme_fn_args.push(fn_arg.into());
4034
4035                    input_activity.push(CDIFFE_TYPE_DFT_CONSTANT);
4036                }
4037            }
4038        }
4039        // if we have void ret, this must be false;
4040        let ret_primary_ret = false;
4041        let diff_ret = false;
4042        let ret_activity = CDIFFE_TYPE_DFT_CONSTANT;
4043        let ret_tree = unsafe {
4044            EnzymeNewTypeTreeCT(
4045                CConcreteType_DT_Anything,
4046                self.context.as_ctx_ref() as *mut LLVMOpaqueContext,
4047            )
4048        };
4049
4050        // always optimize
4051        let fnc_opt_base = true;
4052        let logic_ref: EnzymeLogicRef = unsafe { CreateEnzymeLogic(fnc_opt_base as u8) };
4053
4054        let kv_tmp = IntList {
4055            data: std::ptr::null_mut(),
4056            size: 0,
4057        };
4058        let mut known_values = vec![kv_tmp; input_activity.len()];
4059
4060        let fn_type_info = CFnTypeInfo {
4061            Arguments: arg_trees.as_mut_ptr(),
4062            Return: ret_tree,
4063            KnownValues: known_values.as_mut_ptr(),
4064        };
4065
4066        let type_analysis: EnzymeTypeAnalysisRef =
4067            unsafe { CreateTypeAnalysis(logic_ref, std::ptr::null_mut(), std::ptr::null_mut(), 0) };
4068
4069        let mut args_uncacheable = vec![0; arg_trees.len()];
4070
4071        let enzyme_function = match mode {
4072            CompileMode::Forward | CompileMode::ForwardSens => unsafe {
4073                EnzymeCreateForwardDiff(
4074                    logic_ref, // Logic
4075                    std::ptr::null_mut(),
4076                    std::ptr::null_mut(),
4077                    original_function.as_value_ref(),
4078                    ret_activity, // LLVM function, return type
4079                    input_activity.as_mut_ptr(),
4080                    input_activity.len(), // constant arguments
4081                    type_analysis,        // type analysis struct
4082                    ret_primary_ret as u8,
4083                    CDerivativeMode_DEM_ForwardMode, // return value, dret_used, top_level which was 1
4084                    1,                               // free memory
4085                    0,                               // runtime activity
4086                    0,                               // strong zero
4087                    1,                               // vector mode width
4088                    std::ptr::null_mut(),            // additional argument
4089                    fn_type_info,                    // additional_arg, type info (return + args)
4090                    1,                               // subsequent calls may write
4091                    args_uncacheable.as_mut_ptr(),   // overwritten args
4092                    args_uncacheable.len(),          // overwritten args length
4093                    std::ptr::null_mut(),            // write augmented function to this
4094                )
4095            },
4096            CompileMode::Reverse | CompileMode::ReverseSens => {
4097                let mut call_enzyme = || unsafe {
4098                    EnzymeCreatePrimalAndGradient(
4099                        logic_ref,
4100                        std::ptr::null_mut(),
4101                        std::ptr::null_mut(),
4102                        original_function.as_value_ref(),
4103                        ret_activity,
4104                        input_activity.as_mut_ptr(),
4105                        input_activity.len(),
4106                        type_analysis,
4107                        ret_primary_ret as u8,
4108                        diff_ret as u8,
4109                        CDerivativeMode_DEM_ReverseModeCombined,
4110                        0,
4111                        0, // strong zero
4112                        1,
4113                        1,
4114                        std::ptr::null_mut(),
4115                        0, // force annonymous tape
4116                        fn_type_info,
4117                        0, // subsequent calls may write
4118                        args_uncacheable.as_mut_ptr(),
4119                        args_uncacheable.len(),
4120                        std::ptr::null_mut(),
4121                        if self.threaded { 1 } else { 0 }, // atomic add
4122                    )
4123                };
4124                if self.threaded {
4125                    // the register call handler alters a global variable, so we need to lock it
4126                    let _lock = my_mutex.lock().unwrap();
4127                    let barrier_string = CString::new("barrier").unwrap();
4128                    unsafe {
4129                        EnzymeRegisterCallHandler(
4130                            barrier_string.as_ptr(),
4131                            Some(fwd_handler),
4132                            Some(rev_handler),
4133                        )
4134                    };
4135                    let ret = call_enzyme();
4136                    // unregister it so some other thread doesn't use it
4137                    unsafe { EnzymeRegisterCallHandler(barrier_string.as_ptr(), None, None) };
4138                    ret
4139                } else {
4140                    call_enzyme()
4141                }
4142            }
4143        };
4144
4145        // free everything
4146        unsafe { FreeEnzymeLogic(logic_ref) };
4147        unsafe { FreeTypeAnalysis(type_analysis) };
4148        unsafe { EnzymeFreeTypeTree(ret_tree) };
4149        for tree in arg_trees {
4150            unsafe { EnzymeFreeTypeTree(tree) };
4151        }
4152
4153        // call enzyme function
4154        let enzyme_function =
4155            unsafe { FunctionValue::new(enzyme_function as LLVMValueRef) }.unwrap();
4156        self.builder
4157            .build_call(enzyme_function, enzyme_fn_args.as_slice(), "enzyme_call")?;
4158
4159        // return
4160        self.builder.build_return(None)?;
4161
4162        if function.verify(true) {
4163            Ok(function)
4164        } else {
4165            function.print_to_stderr();
4166            enzyme_function.print_to_stderr();
4167            unsafe {
4168                function.delete();
4169            }
4170            Err(anyhow!("Invalid generated function."))
4171        }
4172    }
4173
4174    pub fn compile_get_dims(&mut self, model: &DiscreteModel) -> Result<FunctionValue<'ctx>> {
4175        self.clear();
4176        let fn_arg_names = &[
4177            "states",
4178            "inputs",
4179            "outputs",
4180            "data",
4181            "stop",
4182            "has_mass",
4183            "has_reset",
4184        ];
4185        let function = self.add_function(
4186            "get_dims",
4187            fn_arg_names,
4188            &[
4189                self.int_ptr_type.into(),
4190                self.int_ptr_type.into(),
4191                self.int_ptr_type.into(),
4192                self.int_ptr_type.into(),
4193                self.int_ptr_type.into(),
4194                self.int_ptr_type.into(),
4195                self.int_ptr_type.into(),
4196            ],
4197            None,
4198            false,
4199        );
4200        let _block = self.start_function(function, None);
4201
4202        for (i, arg) in function.get_param_iter().enumerate() {
4203            let name = fn_arg_names[i];
4204            let alloca = self.function_arg_alloca(name, arg);
4205            self.insert_param(name, alloca);
4206        }
4207
4208        self.insert_indices();
4209
4210        let number_of_states = model.state().nnz() as u64;
4211        let number_of_inputs = model.input().map(|inp| inp.nnz()).unwrap_or(0) as u64;
4212        let number_of_outputs = match model.out() {
4213            Some(out) => out.nnz() as u64,
4214            None => 0,
4215        };
4216        let number_of_stop = if let Some(stop) = model.stop() {
4217            stop.nnz() as u64
4218        } else {
4219            0
4220        };
4221        let has_mass = match model.lhs().is_some() {
4222            true => 1u64,
4223            false => 0u64,
4224        };
4225        let has_reset = match model.reset().is_some() {
4226            true => 1u64,
4227            false => 0u64,
4228        };
4229        let data_len = self.layout.data().len() as u64;
4230        self.builder.build_store(
4231            *self.get_param("states"),
4232            self.int_type.const_int(number_of_states, false),
4233        )?;
4234        self.builder.build_store(
4235            *self.get_param("inputs"),
4236            self.int_type.const_int(number_of_inputs, false),
4237        )?;
4238        self.builder.build_store(
4239            *self.get_param("outputs"),
4240            self.int_type.const_int(number_of_outputs, false),
4241        )?;
4242        self.builder.build_store(
4243            *self.get_param("data"),
4244            self.int_type.const_int(data_len, false),
4245        )?;
4246        self.builder.build_store(
4247            *self.get_param("stop"),
4248            self.int_type.const_int(number_of_stop, false),
4249        )?;
4250        self.builder.build_store(
4251            *self.get_param("has_mass"),
4252            self.int_type.const_int(has_mass, false),
4253        )?;
4254        self.builder.build_store(
4255            *self.get_param("has_reset"),
4256            self.int_type.const_int(has_reset, false),
4257        )?;
4258        self.builder.build_return(None)?;
4259
4260        if function.verify(true) {
4261            Ok(function)
4262        } else {
4263            function.print_to_stderr();
4264            unsafe {
4265                function.delete();
4266            }
4267            Err(anyhow!("Invalid generated function."))
4268        }
4269    }
4270
4271    pub fn compile_get_tensor(
4272        &mut self,
4273        model: &DiscreteModel,
4274        name: &str,
4275    ) -> Result<FunctionValue<'ctx>> {
4276        self.clear();
4277        let real_ptr_ptr_type = Self::pointer_type(self.context, self.real_ptr_type.into());
4278        let function_name = format!("get_tensor_{name}");
4279        let fn_arg_names = &["data", "tensor_data", "tensor_size"];
4280        let function = self.add_function(
4281            function_name.as_str(),
4282            fn_arg_names,
4283            &[
4284                self.real_ptr_type.into(),
4285                real_ptr_ptr_type.into(),
4286                self.int_ptr_type.into(),
4287            ],
4288            None,
4289            false,
4290        );
4291        let _basic_block = self.start_function(function, None);
4292
4293        for (i, arg) in function.get_param_iter().enumerate() {
4294            let name = fn_arg_names[i];
4295            let alloca = self.function_arg_alloca(name, arg);
4296            self.insert_param(name, alloca);
4297        }
4298
4299        self.insert_data(model);
4300        let ptr = self.get_param(name);
4301        let tensor_size = self.layout.get_layout(name).unwrap().nnz() as u64;
4302        let tensor_size_value = self.int_type.const_int(tensor_size, false);
4303        self.builder
4304            .build_store(*self.get_param("tensor_data"), ptr.as_basic_value_enum())?;
4305        self.builder
4306            .build_store(*self.get_param("tensor_size"), tensor_size_value)?;
4307        self.builder.build_return(None)?;
4308
4309        if function.verify(true) {
4310            Ok(function)
4311        } else {
4312            function.print_to_stderr();
4313            unsafe {
4314                function.delete();
4315            }
4316            Err(anyhow!("Invalid generated function."))
4317        }
4318    }
4319
4320    pub fn compile_get_constant(
4321        &mut self,
4322        model: &DiscreteModel,
4323        name: &str,
4324    ) -> Result<FunctionValue<'ctx>> {
4325        self.clear();
4326        let real_ptr_ptr_type = Self::pointer_type(self.context, self.real_ptr_type.into());
4327        let function_name = format!("get_constant_{name}");
4328        let fn_arg_names = &["tensor_data", "tensor_size"];
4329        let function = self.add_function(
4330            function_name.as_str(),
4331            fn_arg_names,
4332            &[real_ptr_ptr_type.into(), self.int_ptr_type.into()],
4333            None,
4334            false,
4335        );
4336        let _basic_block = self.start_function(function, None);
4337
4338        for (i, arg) in function.get_param_iter().enumerate() {
4339            let name = fn_arg_names[i];
4340            let alloca = self.function_arg_alloca(name, arg);
4341            self.insert_param(name, alloca);
4342        }
4343
4344        self.insert_constants(model);
4345        let ptr = self.get_param(name);
4346        let tensor_size = self.layout.get_layout(name).unwrap().nnz() as u64;
4347        let tensor_size_value = self.int_type.const_int(tensor_size, false);
4348        self.builder
4349            .build_store(*self.get_param("tensor_data"), ptr.as_basic_value_enum())?;
4350        self.builder
4351            .build_store(*self.get_param("tensor_size"), tensor_size_value)?;
4352        self.builder.build_return(None)?;
4353
4354        if function.verify(true) {
4355            Ok(function)
4356        } else {
4357            function.print_to_stderr();
4358            unsafe {
4359                function.delete();
4360            }
4361            Err(anyhow!("Invalid generated function."))
4362        }
4363    }
4364
4365    pub fn compile_inputs(
4366        &mut self,
4367        model: &DiscreteModel,
4368        is_get: bool,
4369    ) -> Result<FunctionValue<'ctx>> {
4370        self.clear();
4371        let function_name = if is_get { "get_inputs" } else { "set_inputs" };
4372        let fn_arg_names: &[&str] = if is_get {
4373            &["inputs", "data"]
4374        } else {
4375            &["inputs", "data", "model_index"]
4376        };
4377        let fn_arg_types: &[BasicMetadataTypeEnum<'ctx>] = if is_get {
4378            &[self.real_ptr_type.into(), self.real_ptr_type.into()]
4379        } else {
4380            &[
4381                self.real_ptr_type.into(),
4382                self.real_ptr_type.into(),
4383                self.int_type.into(),
4384            ]
4385        };
4386        let function = self.add_function(function_name, fn_arg_names, fn_arg_types, None, false);
4387        let block = self.start_function(function, None);
4388
4389        for (i, arg) in function.get_param_iter().enumerate() {
4390            let name = fn_arg_names[i];
4391            let alloca = self.function_arg_alloca(name, arg);
4392            self.insert_param(name, alloca);
4393        }
4394
4395        if !is_get {
4396            let model_index = self
4397                .build_load(self.int_type, *self.get_param("model_index"), "model_index")?
4398                .into_int_value();
4399            self.builder
4400                .build_store(self.globals.model_index.as_pointer_value(), model_index)?;
4401        }
4402
4403        if let Some(input) = model.input() {
4404            let name = input.name();
4405            self.insert_tensor(input, false);
4406            let ptr = self.get_var(input);
4407            // loop thru the elements of this input and set/get them using the inputs ptr
4408            let inputs_start_index = self.int_type.const_int(0, false);
4409            let start_index = self.int_type.const_int(0, false);
4410            let end_index = self
4411                .int_type
4412                .const_int(input.nnz().try_into().unwrap(), false);
4413
4414            let input_block = self.context.append_basic_block(function, name);
4415            self.builder.build_unconditional_branch(input_block)?;
4416            self.builder.position_at_end(input_block);
4417            let index = self.builder.build_phi(self.int_type, "i")?;
4418            index.add_incoming(&[(&start_index, block)]);
4419
4420            // loop body - copy value from inputs to data
4421            let curr_input_index = index.as_basic_value().into_int_value();
4422            let input_ptr =
4423                Self::get_ptr_to_index(&self.builder, self.real_type, ptr, curr_input_index, name);
4424            let curr_inputs_index =
4425                self.builder
4426                    .build_int_add(inputs_start_index, curr_input_index, name)?;
4427            let inputs_ptr = Self::get_ptr_to_index(
4428                &self.builder,
4429                self.real_type,
4430                self.get_param("inputs"),
4431                curr_inputs_index,
4432                name,
4433            );
4434            if is_get {
4435                let input_value = self
4436                    .build_load(self.real_type, input_ptr, name)?
4437                    .into_float_value();
4438                self.builder.build_store(inputs_ptr, input_value)?;
4439            } else {
4440                let input_value = self
4441                    .build_load(self.real_type, inputs_ptr, name)?
4442                    .into_float_value();
4443                self.builder.build_store(input_ptr, input_value)?;
4444            }
4445
4446            // increment loop index
4447            let one = self.int_type.const_int(1, false);
4448            let next_index = self.builder.build_int_add(curr_input_index, one, name)?;
4449            index.add_incoming(&[(&next_index, input_block)]);
4450
4451            // loop condition
4452            let loop_while =
4453                self.builder
4454                    .build_int_compare(IntPredicate::ULT, next_index, end_index, name)?;
4455            let post_block = self.context.append_basic_block(function, name);
4456            self.builder
4457                .build_conditional_branch(loop_while, input_block, post_block)?;
4458            self.builder.position_at_end(post_block);
4459        }
4460        self.builder.build_return(None)?;
4461
4462        if function.verify(true) {
4463            Ok(function)
4464        } else {
4465            function.print_to_stderr();
4466            unsafe {
4467                function.delete();
4468            }
4469            Err(anyhow!("Invalid generated function."))
4470        }
4471    }
4472
4473    pub fn compile_set_id(&mut self, model: &DiscreteModel) -> Result<FunctionValue<'ctx>> {
4474        self.clear();
4475        let fn_arg_names = &["id"];
4476        let function = self.add_function(
4477            "set_id",
4478            fn_arg_names,
4479            &[self.real_ptr_type.into()],
4480            None,
4481            false,
4482        );
4483        let mut block = self.start_function(function, None);
4484
4485        for (i, arg) in function.get_param_iter().enumerate() {
4486            let name = fn_arg_names[i];
4487            let alloca = self.function_arg_alloca(name, arg);
4488            self.insert_param(name, alloca);
4489        }
4490
4491        let mut id_index = 0usize;
4492        for (blk, is_algebraic) in zip(model.state().elmts(), model.is_algebraic()) {
4493            let name = blk.name().unwrap_or("unknown");
4494            // loop thru the elements of this state blk and set the corresponding elements of id
4495            let id_start_index = self.int_type.const_int(id_index as u64, false);
4496            let blk_start_index = self.int_type.const_int(0, false);
4497            let blk_end_index = self
4498                .int_type
4499                .const_int(blk.nnz().try_into().unwrap(), false);
4500
4501            let blk_block = self.context.append_basic_block(function, name);
4502            self.builder.build_unconditional_branch(blk_block)?;
4503            self.builder.position_at_end(blk_block);
4504            let index = self.builder.build_phi(self.int_type, "i")?;
4505            index.add_incoming(&[(&blk_start_index, block)]);
4506
4507            // loop body - copy value from inputs to data
4508            let curr_blk_index = index.as_basic_value().into_int_value();
4509            let curr_id_index = self
4510                .builder
4511                .build_int_add(id_start_index, curr_blk_index, name)?;
4512            let id_ptr = Self::get_ptr_to_index(
4513                &self.builder,
4514                self.real_type,
4515                self.get_param("id"),
4516                curr_id_index,
4517                name,
4518            );
4519            let is_algebraic_float = if *is_algebraic { 0.0 } else { 1.0 };
4520            let is_algebraic_value = self.real_type.const_float(is_algebraic_float);
4521            self.builder.build_store(id_ptr, is_algebraic_value)?;
4522
4523            // increment loop index
4524            let one = self.int_type.const_int(1, false);
4525            let next_index = self.builder.build_int_add(curr_blk_index, one, name)?;
4526            index.add_incoming(&[(&next_index, blk_block)]);
4527
4528            // loop condition
4529            let loop_while = self.builder.build_int_compare(
4530                IntPredicate::ULT,
4531                next_index,
4532                blk_end_index,
4533                name,
4534            )?;
4535            let post_block = self.context.append_basic_block(function, name);
4536            self.builder
4537                .build_conditional_branch(loop_while, blk_block, post_block)?;
4538            self.builder.position_at_end(post_block);
4539
4540            // get ready for next blk
4541            block = post_block;
4542            id_index += blk.nnz();
4543        }
4544        self.builder.build_return(None)?;
4545
4546        if function.verify(true) {
4547            Ok(function)
4548        } else {
4549            function.print_to_stderr();
4550            unsafe {
4551                function.delete();
4552            }
4553            Err(anyhow!("Invalid generated function."))
4554        }
4555    }
4556
4557    pub fn module(&self) -> &Module<'ctx> {
4558        &self.module
4559    }
4560}