1use aliasable::boxed::AliasableBox;
2use anyhow::{anyhow, Result};
3use core::panic;
4use inkwell::attributes::{Attribute, AttributeLoc};
5use inkwell::basic_block::BasicBlock;
6use inkwell::builder::Builder;
7use inkwell::context::{AsContextRef, Context};
8use inkwell::debug_info::AsDIScope;
9use inkwell::debug_info::{DICompileUnit, DIFlags, DIFlagsConstants, DebugInfoBuilder};
10use inkwell::execution_engine::ExecutionEngine;
11use inkwell::intrinsics::Intrinsic;
12use inkwell::module::{Linkage, Module};
13use inkwell::passes::PassBuilderOptions;
14use inkwell::targets::{FileType, InitializationConfig, Target, TargetMachine, TargetTriple};
15use inkwell::types::{
16 BasicMetadataTypeEnum, BasicType, BasicTypeEnum, FloatType, FunctionType, IntType, PointerType,
17};
18use inkwell::values::{
19 AsValueRef, BasicMetadataValueEnum, BasicValue, BasicValueEnum, CallSiteValue, FloatValue,
20 FunctionValue, GlobalValue, IntValue, PointerValue,
21};
22use inkwell::{
23 AddressSpace, AtomicOrdering, AtomicRMWBinOp, FloatPredicate, GlobalVisibility, IntPredicate,
24 OptimizationLevel,
25};
26use llvm_sys::core::{
27 LLVMBuildCall2, LLVMGetArgOperand, LLVMGetBasicBlockParent, LLVMGetGlobalParent,
28 LLVMGetInstructionParent, LLVMGetNamedFunction, LLVMGlobalGetValueType, LLVMIsMultithreaded,
29};
30use llvm_sys::prelude::{LLVMBuilderRef, LLVMValueRef};
31use pest::Span;
32use std::collections::HashMap;
33use std::ffi::CString;
34use std::iter::zip;
35use std::pin::Pin;
36use target_lexicon::Triple;
37
38use crate::ast::{Ast, AstKind};
39use crate::discretise::{DiscreteModel, Tensor, TensorBlock};
40use crate::enzyme::{
41 CConcreteType_DT_Anything, CConcreteType_DT_Double, CConcreteType_DT_Float,
42 CConcreteType_DT_Integer, CConcreteType_DT_Pointer, CDerivativeMode_DEM_ForwardMode,
43 CDerivativeMode_DEM_ReverseModeCombined, CFnTypeInfo, CreateEnzymeLogic, CreateTypeAnalysis,
44 DiffeGradientUtils, EnzymeCreateForwardDiff, EnzymeCreatePrimalAndGradient, EnzymeFreeTypeTree,
45 EnzymeGradientUtilsNewFromOriginal, EnzymeLogicRef, EnzymeMergeTypeTree, EnzymeNewTypeTreeCT,
46 EnzymeRegisterCallHandler, EnzymeTypeAnalysisRef, EnzymeTypeTreeOnlyEq, FreeEnzymeLogic,
47 FreeTypeAnalysis, GradientUtils, IntList, LLVMOpaqueContext, CDIFFE_TYPE_DFT_CONSTANT,
48 CDIFFE_TYPE_DFT_DUP_ARG, CDIFFE_TYPE_DFT_DUP_NONEED,
49};
50use crate::execution::compiler::CompilerOptions;
51use crate::execution::module::{
52 CodegenModule, CodegenModuleCompile, CodegenModuleEmit, CodegenModuleJit,
53};
54use crate::execution::scalar::RealType;
55use crate::execution::{DataLayout, Translation, TranslationFrom, TranslationTo};
56use lazy_static::lazy_static;
57use std::sync::Mutex;
58
// Global lock; presumably serializes access to non-thread-safe LLVM/Enzyme
// global state during compilation — confirm at the lock sites (not visible in
// this chunk).
// NOTE(review): `std::sync::Mutex::new` is const since Rust 1.63, so a plain
// `static` (or `std::sync::OnceLock`) could replace `lazy_static!` here.
lazy_static! {
    static ref my_mutex: Mutex<i32> = Mutex::new(0i32);
}
62
/// Self-referential pair of an LLVM `Context` and the `CodeGen` that borrows
/// it. The `CodeGen<'static>` lifetime is produced by a `transmute` in
/// `LlvmModule::new`; it really borrows `context`, which is heap-allocated
/// (`AliasableBox`) and pinned so the borrow stays valid across moves of the
/// outer struct.
struct ImmovableLlvmModule {
    // Declared before `context` so it is dropped first — it borrows `context`.
    codegen: Option<CodeGen<'static>>,
    context: AliasableBox<Context>,
    // Prevents moving the value out of the pin; the address of `context`
    // must never change while `codegen` is alive.
    _pin: std::marker::PhantomPinned,
}
71
/// LLVM-backed codegen module for a `DiscreteModel`: bundles the pinned
/// context/codegen pair with the `TargetMachine` used for running pass
/// pipelines and emitting object code.
pub struct LlvmModule {
    inner: Pin<Box<ImmovableLlvmModule>>,
    machine: TargetMachine,
}
76
// SAFETY(review): inkwell's `Context` is not thread-safe by itself; these
// impls assert that an `LlvmModule` is only accessed from one thread at a
// time — confirm all callers uphold this (e.g. via external locking).
unsafe impl Send for LlvmModule {}
unsafe impl Sync for LlvmModule {}
79
impl LlvmModule {
    /// Build a module for `model`: create a `TargetMachine` for `triple`
    /// (host triple when `None`), pin a fresh LLVM `Context` on the heap,
    /// then construct the `CodeGen` that borrows it.
    fn new(
        triple: Option<Triple>,
        model: &DiscreteModel,
        threaded: bool,
        real_type: RealType,
        debug: bool,
    ) -> Result<Self> {
        let initialization_config = &InitializationConfig::default();
        Target::initialize_all(initialization_config);
        let host_triple = Triple::host();
        // Native (host) builds use the host CPU name/features; cross builds
        // fall back to a generic CPU with no extra features.
        let (triple_str, native) = match triple {
            Some(ref triple) => (triple.to_string(), false),
            None => (host_triple.to_string(), true),
        };
        let triple = TargetTriple::create(triple_str.as_str());
        // Panics on an unknown/unsupported triple.
        let target = Target::from_triple(&triple).unwrap();
        let cpu = if native {
            TargetMachine::get_host_cpu_name().to_string()
        } else {
            "generic".to_string()
        };
        let features = if native {
            TargetMachine::get_host_cpu_features().to_string()
        } else {
            "".to_string()
        };
        let machine = target
            .create_target_machine(
                &triple,
                cpu.as_str(),
                features.as_str(),
                inkwell::OptimizationLevel::Aggressive,
                inkwell::targets::RelocMode::Default,
                inkwell::targets::CodeModel::Default,
            )
            .unwrap();

        let context = AliasableBox::from_unique(Box::new(Context::create()));
        let mut pinned = Self {
            inner: Box::pin(ImmovableLlvmModule {
                codegen: None,
                context,
                _pin: std::marker::PhantomPinned,
            }),
            machine,
        };

        let context_ref = pinned.inner.context.as_ref();
        let real_type_llvm = match real_type {
            RealType::F32 => context_ref.f32_type(),
            RealType::F64 => context_ref.f64_type(),
        };
        let int_type_llvm = context_ref.i32_type();
        // Bit sizes come from the target data layout so debug info agrees
        // with the actual target ABI.
        let ptr_size_bits = pinned
            .machine
            .get_target_data()
            .get_bit_size(&context_ref.ptr_type(AddressSpace::default()));
        let real_size_bits = pinned
            .machine
            .get_target_data()
            .get_bit_size(&real_type_llvm);
        let int_size_bits = pinned
            .machine
            .get_target_data()
            .get_bit_size(&int_type_llvm);
        let codegen = CodeGen::new(
            model,
            context_ref,
            real_type,
            real_type_llvm,
            int_type_llvm,
            threaded,
            ptr_size_bits,
            real_size_bits,
            int_size_bits,
            debug,
        )?;
        // SAFETY(review): extends the borrow of the pinned, heap-allocated
        // context to 'static so `codegen` can be stored next to it. Sound
        // only because the context is pinned (never moved) and outlives
        // `codegen` (field declaration order fixes the drop order).
        let codegen = unsafe { std::mem::transmute::<CodeGen<'_>, CodeGen<'static>>(codegen) };
        // SAFETY: we only write a field; the pinned value itself is not moved.
        unsafe { pinned.inner.as_mut().get_unchecked_mut().codegen = Some(codegen) };
        Ok(pinned)
    }

    /// Run an O3-like pass pipeline with loop vectorization, SLP
    /// vectorization and unrolling disabled — presumably to keep the IR in a
    /// form Enzyme can differentiate (TODO confirm). Autodiff runs after this.
    fn pre_autodiff_optimisation(&mut self) -> Result<()> {
        let pass_options = PassBuilderOptions::create();
        pass_options.set_loop_vectorization(false);
        pass_options.set_loop_slp_vectorization(false);
        pass_options.set_loop_unrolling(false);
        // Explicit textual pipeline (new pass manager syntax): essentially
        // the default O3 pipeline with the vectorization passes removed.
        let passes = "annotation2metadata,forceattrs,inferattrs,coro-early,function<eager-inv>(lower-expect,simplifycfg<bonus-inst-threshold=1;no-forward-switch-cond;no-switch-range-to-icmp;no-switch-to-lookup;keep-loops;no-hoist-common-insts;no-sink-common-insts>,early-cse<>),openmp-opt,ipsccp,called-value-propagation,globalopt,function(mem2reg),function<eager-inv>(instcombine,simplifycfg<bonus-inst-threshold=1;no-forward-switch-cond;switch-range-to-icmp;no-switch-to-lookup;keep-loops;no-hoist-common-insts;no-sink-common-insts>),require<globals-aa>,function(invalidate<aa>),require<profile-summary>,cgscc(devirt<4>(inline<only-mandatory>,inline,function-attrs,openmp-opt-cgscc,function<eager-inv>(early-cse<memssa>,speculative-execution,jump-threading,correlated-propagation,simplifycfg<bonus-inst-threshold=1;no-forward-switch-cond;switch-range-to-icmp;no-switch-to-lookup;keep-loops;no-hoist-common-insts;no-sink-common-insts>,instcombine,libcalls-shrinkwrap,tailcallelim,simplifycfg<bonus-inst-threshold=1;no-forward-switch-cond;switch-range-to-icmp;no-switch-to-lookup;keep-loops;no-hoist-common-insts;no-sink-common-insts>,reassociate,require<opt-remark-emit>,loop-mssa(loop-instsimplify,loop-simplifycfg,licm<no-allowspeculation>,loop-rotate,licm<allowspeculation>,simple-loop-unswitch<no-nontrivial;trivial>),simplifycfg<bonus-inst-threshold=1;no-forward-switch-cond;switch-range-to-icmp;no-switch-to-lookup;keep-loops;no-hoist-common-insts;no-sink-common-insts>,instcombine,loop(loop-idiom,indvars,loop-deletion),vector-combine,mldst-motion<no-split-footer-bb>,gvn<>,sccp,bdce,instcombine,jump-threading,correlated-propagation,adce,memcpyopt,dse,loop-mssa(licm<allowspeculation>),coro-elide,simplifycfg<bonus-inst-threshold=1;no-forward-switch-cond;switch-range-to-icmp;no-switch-to-lookup;keep-loops;hoist-common-insts;sink-common-insts>,instcombine),coro-split)),deadargelim,coro-cleanup,globalopt,globaldce,elim-avail-extern,rpo-function-attrs,recompute-globalsaa,function<eager-inv>(float2int,lower-constant-intrinsics,loop(loop-rotate,loop-deletion),loop-distribute,inject-tli-mappings,loop-load-elim,instcombine,simplifycfg<bonus-inst-threshold=1;forward-switch-cond;switch-range-to-icmp;switch-to-lookup;no-keep-loops;hoist-common-insts;sink-common-insts>,vector-combine,instcombine,transform-warning,instcombine,require<opt-remark-emit>,loop-mssa(licm<allowspeculation>),alignment-from-assumptions,loop-sink,instsimplify,div-rem-pairs,tailcallelim,simplifycfg<bonus-inst-threshold=1;no-forward-switch-cond;switch-range-to-icmp;no-switch-to-lookup;keep-loops;no-hoist-common-insts;no-sink-common-insts>),globaldce,constmerge,cg-profile,rel-lookup-table-converter,function(annotation-remarks),verify";
        let (codegen, machine) = self.codegen_and_machine_mut();
        codegen
            .module()
            .run_passes(passes, machine, pass_options)
            .map_err(|e| anyhow!("Failed to run passes: {:?}", e))
    }

    /// Clean up after Enzyme: allow `barrier` to be inlined again, delete
    /// leftover `preprocess_*` functions, then run the full O3 pipeline.
    fn post_autodiff_optimisation(&mut self) -> Result<()> {
        if let Some(barrier_func) = self.codegen_mut().module().get_function("barrier") {
            // `barrier` was created with `noinline` (see `compile_barrier`);
            // that constraint is no longer needed once autodiff is done.
            let nolinline_kind_id = Attribute::get_named_enum_kind_id("noinline");
            barrier_func.remove_enum_attribute(AttributeLoc::Function, nolinline_kind_id);
        }

        // Drop `preprocess_*` functions (presumably Enzyme intermediates —
        // TODO confirm) so they do not end up in the final object.
        for f in self.codegen_mut().module.get_functions() {
            if f.get_name().to_str().unwrap().starts_with("preprocess_") {
                unsafe { f.delete() };
            }
        }

        let passes = "default<O3>";
        let (codegen, machine) = self.codegen_and_machine_mut();
        codegen
            .module()
            .run_passes(passes, machine, PassBuilderOptions::create())
            .map_err(|e| anyhow!("Failed to run passes: {:?}", e))?;

        Ok(())
    }

    /// Dump the module IR to stderr (debugging aid).
    pub fn print(&self) {
        self.codegen().module().print_to_stderr();
    }

    /// Mutable access to the codegen stored behind the pin.
    fn codegen_mut(&mut self) -> &mut CodeGen<'static> {
        // SAFETY: only a field reference is taken; the pinned value is never
        // moved.
        unsafe {
            self.inner
                .as_mut()
                .get_unchecked_mut()
                .codegen
                .as_mut()
                .unwrap()
        }
    }

    /// Mutable codegen together with the target machine — needed because a
    /// single `&mut self` borrow cannot hand out both otherwise.
    fn codegen_and_machine_mut(&mut self) -> (&mut CodeGen<'static>, &TargetMachine) {
        (
            // SAFETY: as in `codegen_mut` — no move of the pinned value.
            unsafe {
                self.inner
                    .as_mut()
                    .get_unchecked_mut()
                    .codegen
                    .as_mut()
                    .unwrap()
            },
            &self.machine,
        )
    }

    /// Shared access to the codegen. `unwrap` cannot fail once `new` has
    /// returned, since `new` always populates `codegen`.
    fn codegen(&self) -> &CodeGen<'static> {
        self.inner.as_ref().get_ref().codegen.as_ref().unwrap()
    }
}
270
// Empty impl: any items on `CodegenModule` are satisfied by trait defaults.
impl CodegenModule for LlvmModule {}
272
273impl CodegenModuleCompile for LlvmModule {
274 fn from_discrete_model(
275 model: &DiscreteModel,
276 options: CompilerOptions,
277 triple: Option<Triple>,
278 real_type: RealType,
279 code: Option<&str>,
280 ) -> Result<Self> {
281 let thread_dim = options.mode.thread_dim(model.state().nnz());
282 let threaded = thread_dim > 1;
283 if (unsafe { LLVMIsMultithreaded() } <= 0) {
284 return Err(anyhow!(
285 "LLVM is not compiled with multithreading support, but this codegen module requires it."
286 ));
287 }
288
289 let mut module = Self::new(triple, model, threaded, real_type, options.debug)?;
290
291 let set_u0 = module.codegen_mut().compile_set_u0(model, code)?;
292 let _calc_stop = module.codegen_mut().compile_calc_stop(model, code)?;
293 let rhs = module.codegen_mut().compile_rhs(model, false, code)?;
294 let rhs_full = module.codegen_mut().compile_rhs(model, true, code)?;
295 let mass = module.codegen_mut().compile_mass(model, code)?;
296 let calc_out = module.codegen_mut().compile_calc_out(model, false, code)?;
297 let calc_out_full = module.codegen_mut().compile_calc_out(model, true, code)?;
298 let _set_id = module.codegen_mut().compile_set_id(model)?;
299 let _get_dims = module.codegen_mut().compile_get_dims(model)?;
300 let set_inputs = module.codegen_mut().compile_inputs(model, false)?;
301 let _get_inputs = module.codegen_mut().compile_inputs(model, true)?;
302 let _set_constants = module.codegen_mut().compile_set_constants(model, code)?;
303 let tensor_info = module
304 .codegen()
305 .layout
306 .tensors()
307 .map(|(name, is_constant)| (name.to_string(), is_constant))
308 .collect::<Vec<_>>();
309 for (tensor, is_constant) in tensor_info {
310 if is_constant {
311 module
312 .codegen_mut()
313 .compile_get_constant(model, tensor.as_str())?;
314 } else {
315 module
316 .codegen_mut()
317 .compile_get_tensor(model, tensor.as_str())?;
318 }
319 }
320
321 module.pre_autodiff_optimisation()?;
322
323 module.codegen_mut().compile_gradient(
324 set_u0,
325 &[
326 CompileGradientArgType::DupNoNeed,
327 CompileGradientArgType::DupNoNeed,
328 CompileGradientArgType::Const,
329 CompileGradientArgType::Const,
330 ],
331 CompileMode::Forward,
332 "set_u0_grad",
333 )?;
334
335 module.codegen_mut().compile_gradient(
336 rhs,
337 &[
338 CompileGradientArgType::Const,
339 CompileGradientArgType::DupNoNeed,
340 CompileGradientArgType::DupNoNeed,
341 CompileGradientArgType::DupNoNeed,
342 CompileGradientArgType::Const,
343 CompileGradientArgType::Const,
344 ],
345 CompileMode::Forward,
346 "rhs_grad",
347 )?;
348
349 module.codegen_mut().compile_gradient(
350 calc_out,
351 &[
352 CompileGradientArgType::Const,
353 CompileGradientArgType::DupNoNeed,
354 CompileGradientArgType::DupNoNeed,
355 CompileGradientArgType::DupNoNeed,
356 CompileGradientArgType::Const,
357 CompileGradientArgType::Const,
358 ],
359 CompileMode::Forward,
360 "calc_out_grad",
361 )?;
362 module.codegen_mut().compile_gradient(
363 set_inputs,
364 &[
365 CompileGradientArgType::DupNoNeed,
366 CompileGradientArgType::DupNoNeed,
367 ],
368 CompileMode::Forward,
369 "set_inputs_grad",
370 )?;
371
372 module.codegen_mut().compile_gradient(
373 set_u0,
374 &[
375 CompileGradientArgType::DupNoNeed,
376 CompileGradientArgType::DupNoNeed,
377 CompileGradientArgType::Const,
378 CompileGradientArgType::Const,
379 ],
380 CompileMode::Reverse,
381 "set_u0_rgrad",
382 )?;
383
384 module.codegen_mut().compile_gradient(
385 mass,
386 &[
387 CompileGradientArgType::Const,
388 CompileGradientArgType::DupNoNeed,
389 CompileGradientArgType::DupNoNeed,
390 CompileGradientArgType::DupNoNeed,
391 CompileGradientArgType::Const,
392 CompileGradientArgType::Const,
393 ],
394 CompileMode::Reverse,
395 "mass_rgrad",
396 )?;
397
398 module.codegen_mut().compile_gradient(
399 rhs,
400 &[
401 CompileGradientArgType::Const,
402 CompileGradientArgType::DupNoNeed,
403 CompileGradientArgType::DupNoNeed,
404 CompileGradientArgType::DupNoNeed,
405 CompileGradientArgType::Const,
406 CompileGradientArgType::Const,
407 ],
408 CompileMode::Reverse,
409 "rhs_rgrad",
410 )?;
411 module.codegen_mut().compile_gradient(
412 calc_out,
413 &[
414 CompileGradientArgType::Const,
415 CompileGradientArgType::DupNoNeed,
416 CompileGradientArgType::DupNoNeed,
417 CompileGradientArgType::DupNoNeed,
418 CompileGradientArgType::Const,
419 CompileGradientArgType::Const,
420 ],
421 CompileMode::Reverse,
422 "calc_out_rgrad",
423 )?;
424
425 module.codegen_mut().compile_gradient(
426 set_inputs,
427 &[
428 CompileGradientArgType::DupNoNeed,
429 CompileGradientArgType::DupNoNeed,
430 ],
431 CompileMode::Reverse,
432 "set_inputs_rgrad",
433 )?;
434
435 module.codegen_mut().compile_gradient(
436 rhs_full,
437 &[
438 CompileGradientArgType::Const,
439 CompileGradientArgType::Const,
440 CompileGradientArgType::DupNoNeed,
441 CompileGradientArgType::DupNoNeed,
442 CompileGradientArgType::Const,
443 CompileGradientArgType::Const,
444 ],
445 CompileMode::ForwardSens,
446 "rhs_sgrad",
447 )?;
448
449 module.codegen_mut().compile_gradient(
450 set_u0,
451 &[
452 CompileGradientArgType::DupNoNeed,
453 CompileGradientArgType::DupNoNeed,
454 CompileGradientArgType::Const,
455 CompileGradientArgType::Const,
456 ],
457 CompileMode::ForwardSens,
458 "set_u0_sgrad",
459 )?;
460
461 module.codegen_mut().compile_gradient(
462 calc_out_full,
463 &[
464 CompileGradientArgType::Const,
465 CompileGradientArgType::Const,
466 CompileGradientArgType::DupNoNeed,
467 CompileGradientArgType::DupNoNeed,
468 CompileGradientArgType::Const,
469 CompileGradientArgType::Const,
470 ],
471 CompileMode::ForwardSens,
472 "calc_out_sgrad",
473 )?;
474 module.codegen_mut().compile_gradient(
475 calc_out_full,
476 &[
477 CompileGradientArgType::Const,
478 CompileGradientArgType::Const,
479 CompileGradientArgType::DupNoNeed,
480 CompileGradientArgType::DupNoNeed,
481 CompileGradientArgType::Const,
482 CompileGradientArgType::Const,
483 ],
484 CompileMode::ReverseSens,
485 "calc_out_srgrad",
486 )?;
487
488 module.codegen_mut().compile_gradient(
489 rhs_full,
490 &[
491 CompileGradientArgType::Const,
492 CompileGradientArgType::Const,
493 CompileGradientArgType::DupNoNeed,
494 CompileGradientArgType::DupNoNeed,
495 CompileGradientArgType::Const,
496 CompileGradientArgType::Const,
497 ],
498 CompileMode::ReverseSens,
499 "rhs_srgrad",
500 )?;
501
502 module.post_autodiff_optimisation()?;
503 Ok(module)
504 }
505}
506
507impl CodegenModuleEmit for LlvmModule {
508 fn to_object(self) -> Result<Vec<u8>> {
509 let module = self.codegen().module();
510 let buffer = self
512 .machine
513 .write_to_memory_buffer(module, FileType::Object)
514 .unwrap()
515 .as_slice()
516 .to_vec();
517 Ok(buffer)
518 }
519}
impl CodegenModuleJit for LlvmModule {
    /// JIT-compile the module and return a map from every resolvable
    /// function name to its executable address.
    ///
    /// Functions whose address cannot be resolved (e.g. external
    /// declarations) are silently skipped rather than treated as errors.
    fn jit(&mut self) -> Result<HashMap<String, *const u8>> {
        let ee = self
            .codegen()
            .module()
            .create_jit_execution_engine(OptimizationLevel::Default)
            .map_err(|e| anyhow!("Failed to create JIT execution engine: {:?}", e))?;

        let module = self.codegen().module();
        let mut symbols = HashMap::new();
        for function in module.get_functions() {
            let name = function.get_name().to_str().unwrap();
            let address = ee.get_function_address(name);
            if let Ok(address) = address {
                symbols.insert(name.to_string(), address as *const u8);
            }
        }
        Ok(symbols)
    }
}
545
/// Module-level LLVM globals shared by all generated functions. Each is
/// `Option` because it is only created when needed (see `Globals::new`).
struct Globals<'ctx> {
    // Constant array of layout indices, baked in at compile time.
    indices: Option<GlobalValue<'ctx>>,
    // Zero-initialized array filled at runtime by `set_constants`.
    constants: Option<GlobalValue<'ctx>>,
    // Shared counter used by the spin barrier; only present when threaded.
    thread_counter: Option<GlobalValue<'ctx>>,
}
551
impl<'ctx> Globals<'ctx> {
    /// Create the module globals for `layout`.
    ///
    /// The `enzyme_const_` name prefix presumably marks these globals as
    /// constant/inactive for Enzyme's activity analysis — TODO confirm
    /// against Enzyme's documentation.
    fn new(
        layout: &DataLayout,
        module: &Module<'ctx>,
        int_type: IntType<'ctx>,
        real_type: FloatType<'ctx>,
        threaded: bool,
    ) -> Self {
        // Barrier counter: only needed for threaded execution.
        let thread_counter = if threaded {
            let tc = module.add_global(
                int_type,
                Some(AddressSpace::default()),
                "enzyme_const_thread_counter",
            );
            let tc_value = int_type.const_zero();
            tc.set_visibility(GlobalVisibility::Hidden);
            tc.set_initializer(&tc_value.as_basic_value_enum());
            Some(tc)
        } else {
            None
        };
        // Constants array: zero-initialized here; written at runtime by the
        // generated `set_constants`, hence `set_constant(false)`.
        let constants = if layout.constants().is_empty() {
            None
        } else {
            let constants_array_type =
                real_type.array_type(u32::try_from(layout.constants().len()).unwrap());
            let constants = module.add_global(
                constants_array_type,
                Some(AddressSpace::default()),
                "enzyme_const_constants",
            );
            constants.set_visibility(GlobalVisibility::Hidden);
            constants.set_constant(false);
            constants.set_initializer(&constants_array_type.const_zero());
            Some(constants)
        };
        // Index array: values are known now, so it is a true constant.
        let indices = if layout.indices().is_empty() {
            None
        } else {
            let indices_array_type =
                int_type.array_type(u32::try_from(layout.indices().len()).unwrap());
            let indices_array_values = layout
                .indices()
                .iter()
                // NOTE(review): `true` sign-extends; indices look like they
                // should be non-negative — confirm intent.
                .map(|&i| int_type.const_int(i as u64, true))
                .collect::<Vec<IntValue>>();
            let indices_value = int_type.const_array(indices_array_values.as_slice());
            let indices = module.add_global(
                indices_array_type,
                Some(AddressSpace::default()),
                "enzyme_const_indices",
            );
            indices.set_constant(true);
            indices.set_visibility(GlobalVisibility::Hidden);
            indices.set_initializer(&indices_value);
            Some(indices)
        };
        Self {
            indices,
            thread_counter,
            constants,
        }
    }
}
622
/// Activity of one argument when differentiating with Enzyme. The variants
/// mirror Enzyme's `CDIFFE_TYPE_DFT_CONSTANT` / `DFT_DUP_ARG` /
/// `DFT_DUP_NONEED` argument classifications (see the `crate::enzyme` imports).
pub enum CompileGradientArgType {
    // Inactive argument: no derivative propagated.
    Const,
    // Active: a shadow argument is passed and the primal value is needed.
    Dup,
    // Active: shadow passed, but the primal result is not needed.
    DupNoNeed,
}
628
/// Which autodiff transformation to apply to a function. The `*Sens`
/// variants are used with the `_full` primal variants for sensitivity
/// analysis (see `from_discrete_model`).
pub enum CompileMode {
    Forward,
    ForwardSens,
    Reverse,
    ReverseSens,
}
635
/// Per-module code generator: owns the inkwell builder state and all
/// bookkeeping needed while lowering a `DiscreteModel` to LLVM IR.
pub struct CodeGen<'ctx> {
    context: &'ctx inkwell::context::Context,
    module: Module<'ctx>,
    builder: Builder<'ctx>,
    // Debug-info builders; `None` unless compiled with `debug = true`.
    dibuilder: Option<DebugInfoBuilder<'ctx>>,
    compile_unit: Option<DICompileUnit<'ctx>>,
    // Named pointers (tensor storage, function args) visible to the
    // function currently being generated.
    variables: HashMap<String, PointerValue<'ctx>>,
    // All functions added so far, by name.
    functions: HashMap<String, FunctionValue<'ctx>>,
    // Function currently being generated, if any.
    fn_value_opt: Option<FunctionValue<'ctx>>,
    tensor_ptr_opt: Option<PointerValue<'ctx>>,
    // Scalar precision as requested by the caller (F32/F64)…
    diffsl_real_type: RealType,
    // …and its LLVM representation plus derived pointer types.
    real_type: FloatType<'ctx>,
    real_ptr_type: PointerType<'ctx>,
    int_type: IntType<'ctx>,
    int_ptr_type: PointerType<'ctx>,
    // Target-ABI bit widths, used for debug info (see `add_function`).
    ptr_size_bits: u64,
    int_size_bits: u64,
    real_size_bits: u64,
    layout: DataLayout,
    globals: Globals<'ctx>,
    threaded: bool,
    _ee: Option<ExecutionEngine<'ctx>>,
}
659
/// Enzyme forward-mode custom call handler (presumably registered for the
/// `barrier` function via `EnzymeRegisterCallHandler` — the registration site
/// is not visible in this chunk). It does nothing and returns 1, which
/// presumably tells Enzyme the call was handled / should proceed unchanged —
/// TODO confirm against Enzyme's CallHandler documentation.
unsafe extern "C" fn fwd_handler(
    _builder: LLVMBuilderRef,
    _call_instruction: LLVMValueRef,
    _gutils: *mut GradientUtils,
    _dcall: *mut LLVMValueRef,
    _normal_return: *mut LLVMValueRef,
    _shadow_return: *mut LLVMValueRef,
) -> u8 {
    1
}
670
/// Enzyme reverse-mode custom call handler for `barrier`: in the generated
/// gradient it emits a call to `barrier_grad` with the original call's three
/// arguments (barrier number, total barriers, thread count), each remapped
/// into the gradient function via `EnzymeGradientUtilsNewFromOriginal`.
///
/// SAFETY(review): assumes `barrier_grad` exists in the module (it is
/// generated in `CodeGen::new` whenever `threaded` is true) — a null
/// `barrier_func` here would be UB.
unsafe extern "C" fn rev_handler(
    builder: LLVMBuilderRef,
    call_instruction: LLVMValueRef,
    gutils: *mut DiffeGradientUtils,
    _tape: LLVMValueRef,
) {
    // Walk instruction -> block -> function -> module to find `barrier_grad`.
    let call_block = LLVMGetInstructionParent(call_instruction);
    let call_function = LLVMGetBasicBlockParent(call_block);
    let module = LLVMGetGlobalParent(call_function);
    let name_c_str = CString::new("barrier_grad").unwrap();
    let barrier_func = LLVMGetNamedFunction(module, name_c_str.as_ptr());
    let barrier_func_type = LLVMGlobalGetValueType(barrier_func);
    // Original call operands: (barrier_num, total_barriers, thread_count).
    let barrier_num = LLVMGetArgOperand(call_instruction, 0);
    let total_barriers = LLVMGetArgOperand(call_instruction, 1);
    let thread_count = LLVMGetArgOperand(call_instruction, 2);
    // Remap each operand from the primal function into the gradient function.
    let barrier_num = EnzymeGradientUtilsNewFromOriginal(gutils as *mut GradientUtils, barrier_num);
    let total_barriers =
        EnzymeGradientUtilsNewFromOriginal(gutils as *mut GradientUtils, total_barriers);
    let thread_count =
        EnzymeGradientUtilsNewFromOriginal(gutils as *mut GradientUtils, thread_count);
    let mut args = [barrier_num, total_barriers, thread_count];
    // Empty name: the call returns void.
    let name_c_str = CString::new("").unwrap();
    LLVMBuildCall2(
        builder,
        barrier_func_type,
        barrier_func,
        args.as_mut_ptr(),
        args.len() as u32,
        name_c_str.as_ptr(),
    );
}
702
// A value to be printed by the debugging helper `compile_print_value`,
// tagged with how to format it (%f vs %d).
#[allow(dead_code)]
enum PrintValue<'ctx> {
    Real(FloatValue<'ctx>),
    Int(IntValue<'ctx>),
}
708
709impl<'ctx> CodeGen<'ctx> {
    #[allow(clippy::too_many_arguments)]
    /// Create a code generator for `model`.
    ///
    /// Sets up the module, (optional) debug-info builders, the shared
    /// globals, and — when `threaded` — pre-compiles the barrier functions
    /// (`barrier_init`, `barrier`, `barrier_grad`) that the generated code
    /// and the Enzyme handlers rely on.
    ///
    /// The `*_size_bits` arguments carry target-ABI bit widths used only for
    /// debug info (see `add_function`).
    pub fn new(
        model: &DiscreteModel,
        context: &'ctx inkwell::context::Context,
        diffsl_real_type: RealType,
        real_type: FloatType<'ctx>,
        int_type: IntType<'ctx>,
        threaded: bool,
        ptr_size_bits: u64,
        real_size_bits: u64,
        int_size_bits: u64,
        debug: bool,
    ) -> Result<Self> {
        let builder = context.create_builder();
        let layout = DataLayout::new(model);
        let module = context.create_module(model.name());
        // Debug info is emitted as DWARF with C as the source language.
        let (dibuilder, compile_unit) = if debug {
            let (dib, dic) = module.create_debug_info_builder(
                true,
                inkwell::debug_info::DWARFSourceLanguage::C,
                model.name(),
                ".",
                "diffsl compiler",
                true,
                "",
                0,
                "",
                inkwell::debug_info::DWARFEmissionKind::Full,
                0,
                false,
                false,
                "",
                "",
            );
            (Some(dib), Some(dic))
        } else {
            (None, None)
        };
        let globals = Globals::new(&layout, &module, int_type, real_type, threaded);
        let real_ptr_type = Self::pointer_type(context, real_type.into());
        let int_ptr_type = Self::pointer_type(context, int_type.into());
        let mut ret = Self {
            context,
            module,
            builder,
            dibuilder,
            compile_unit,
            real_type,
            real_ptr_type,
            variables: HashMap::new(),
            functions: HashMap::new(),
            fn_value_opt: None,
            tensor_ptr_opt: None,
            layout,
            diffsl_real_type,
            int_type,
            int_ptr_type,
            globals,
            threaded,
            ptr_size_bits,
            real_size_bits,
            int_size_bits,
            _ee: None,
        };
        if threaded {
            ret.compile_barrier_init()?;
            ret.compile_barrier()?;
            ret.compile_barrier_grad()?;
        }
        Ok(ret)
    }
783
    /// Begin emitting into `function`: append its entry block, record it as
    /// the current function, and position the builder there.
    ///
    /// When debug info is enabled, also sets an initial (line 0) debug
    /// location; this relies on `add_function` having attached a subprogram,
    /// otherwise the `unwrap` on `get_subprogram` panics.
    fn start_function(
        &mut self,
        function: FunctionValue<'ctx>,
        _code: Option<&str>,
    ) -> BasicBlock<'ctx> {
        let basic_block = self.context.append_basic_block(function, "entry");
        self.fn_value_opt = Some(function);
        self.builder.position_at_end(basic_block);

        if let Some(dibuilder) = &self.dibuilder {
            let scope = self
                .fn_value_opt
                .unwrap()
                .get_subprogram()
                .unwrap()
                .as_debug_info_scope();
            let loc = dibuilder.create_debug_location(self.context, 0, 0, scope, None);
            self.builder.set_current_debug_location(loc);
        }
        basic_block
    }
805
    /// Declare a function in the module and register it in `self.functions`.
    ///
    /// The return type is `real_type` when `is_real_return` is set, `void`
    /// otherwise; the function type is always non-variadic. When debug info
    /// is enabled, a matching DWARF subprogram is attached (argument sizes
    /// come from the target-ABI bit widths captured in `new`).
    fn add_function(
        &mut self,
        name: &str,
        arg_names: &[&str],
        arg_types: &[BasicMetadataTypeEnum<'ctx>],
        linkage: Option<Linkage>,
        is_real_return: bool,
    ) -> FunctionValue<'ctx> {
        let function_type = if is_real_return {
            self.real_type.fn_type(arg_types, false)
        } else {
            self.context.void_type().fn_type(arg_types, false)
        };
        let function = self.module.add_function(name, function_type, linkage);
        if let (Some(dibuilder), Some(compile_unit)) = (&self.dibuilder, &self.compile_unit) {
            // One basic DI type per argument, sized by its ABI width.
            let ditypes = arg_names
                .iter()
                .zip(arg_types.iter())
                .map(|(&name, &ty)| {
                    let size_in_bits = if ty.is_float_type() {
                        self.real_size_bits
                    } else if ty.is_int_type() {
                        self.int_size_bits
                    } else if ty.is_pointer_type() {
                        self.ptr_size_bits
                    } else {
                        unreachable!("Unsupported argument type for debug info")
                    };
                    dibuilder
                        .create_basic_type(
                            name,
                            size_in_bits,
                            0x00,
                            <DIFlags as DIFlagsConstants>::PUBLIC,
                        )
                        .unwrap()
                        .as_type()
                })
                .collect::<Vec<_>>();
            let subroutine_type = dibuilder.create_subroutine_type(
                compile_unit.get_file(),
                None,
                &ditypes,
                <DIFlags as DIFlagsConstants>::PUBLIC,
            );
            let func_scope = dibuilder.create_function(
                compile_unit.as_debug_info_scope(),
                name,
                None,
                compile_unit.get_file(),
                0,
                subroutine_type,
                true,
                true,
                0,
                <DIFlags as DIFlagsConstants>::PUBLIC,
                true,
            );
            // Required by `start_function`, which unwraps the subprogram.
            function.set_subprogram(func_scope);
        }
        self.functions.insert(name.to_owned(), function);
        function
    }
869
    #[allow(dead_code)]
    /// Debugging helper: emit a `printf` call printing `name: <value>`.
    ///
    /// Declares `printf` on first use and caches a per-`name` format-string
    /// global so repeated prints of the same value reuse it.
    ///
    /// NOTE(review): `printf` is declared here as a non-variadic,
    /// void-returning function of one pointer, yet is called with two
    /// arguments — this relies on LLVM tolerating the mismatch; confirm it
    /// verifies on all targets.
    fn compile_print_value(
        &mut self,
        name: &str,
        value: PrintValue<'ctx>,
    ) -> Result<CallSiteValue<'_>> {
        let printf = match self.module.get_function("printf") {
            Some(f) => f,
            None => self.add_function(
                "printf",
                &["format"],
                &[self.int_ptr_type.into()],
                Some(Linkage::External),
                false,
            ),
        };
        // %f for reals, %d for ints; one global per printed name.
        let (format_str, format_str_name) = match value {
            PrintValue::Real(_) => (format!("{name}: %f\n"), format!("real_format_{name}")),
            PrintValue::Int(_) => (format!("{name}: %d\n"), format!("int_format_{name}")),
        };
        let format_str = CString::new(format_str).unwrap();
        let format_str_global = match self.module.get_global(format_str_name.as_str()) {
            Some(g) => g,
            None => {
                let format_str = self.context.const_string(format_str.as_bytes(), true);
                let fmt_str =
                    self.module
                        .add_global(format_str.get_type(), None, format_str_name.as_str());
                fmt_str.set_initializer(&format_str);
                fmt_str.set_visibility(GlobalVisibility::Hidden);
                fmt_str
            }
        };
        let format_str_ptr = self.builder.build_pointer_cast(
            format_str_global.as_pointer_value(),
            self.int_ptr_type,
            "format_str_ptr",
        )?;
        let value: BasicMetadataValueEnum = match value {
            PrintValue::Real(v) => v.into(),
            PrintValue::Int(v) => v.into(),
        };
        self.builder
            .build_call(printf, &[format_str_ptr.into(), value], "printf_call")
            .map_err(|e| anyhow!("Error building call to printf: {}", e))
    }
921
    /// Generate `set_constants(thread_id, thread_dim)`: computes every
    /// constant tensor of the model into the `constants` global, with a
    /// barrier after each tensor so threads stay in lockstep.
    ///
    /// Returns an error (after deleting the function) if the generated IR
    /// fails LLVM verification.
    fn compile_set_constants(
        &mut self,
        model: &DiscreteModel,
        code: Option<&str>,
    ) -> Result<FunctionValue<'ctx>> {
        self.clear();
        let fn_arg_names = &["thread_id", "thread_dim"];
        let function = self.add_function(
            "set_constants",
            fn_arg_names,
            &[self.int_type.into(), self.int_type.into()],
            None,
            false,
        );
        let _basic_block = self.start_function(function, code);

        // Spill each argument to an alloca and register it by name.
        for (i, arg) in function.get_param_iter().enumerate() {
            let name = fn_arg_names[i];
            let alloca = self.function_arg_alloca(name, arg);
            self.insert_param(name, alloca);
        }

        self.insert_indices();
        self.insert_constants(model);

        // One barrier per constant tensor, numbered sequentially.
        let mut nbarriers = 0;
        let total_barriers = (model.constant_defns().len()) as u64;
        #[allow(clippy::explicit_counter_loop)]
        for a in model.constant_defns() {
            self.jit_compile_tensor(a, Some(*self.get_var(a)), code)?;
            self.jit_compile_call_barrier(nbarriers, total_barriers);
            nbarriers += 1;
        }

        self.builder.build_return(None)?;

        if function.verify(true) {
            Ok(function)
        } else {
            function.print_to_stderr();
            self.module.print_to_stderr();
            unsafe {
                function.delete();
            }
            Err(anyhow!("Invalid generated function."))
        }
    }
969
    /// Generate `barrier_init()`: resets the shared thread counter to zero.
    /// Must be called (by the runtime) before a threaded kernel launch so
    /// `barrier` starts from a clean counter.
    fn compile_barrier_init(&mut self) -> Result<FunctionValue<'ctx>> {
        self.clear();
        let function = self.add_function("barrier_init", &[], &[], None, false);
        let _entry_block = self.start_function(function, None);

        // `thread_counter` exists whenever `threaded` is true, which is the
        // only case in which this function is compiled (see `new`).
        let thread_counter = self.globals.thread_counter.unwrap().as_pointer_value();
        self.builder
            .build_store(thread_counter, self.int_type.const_zero())?;

        self.builder.build_return(None)?;

        if function.verify(true) {
            Ok(function)
        } else {
            function.print_to_stderr();
            unsafe {
                function.delete();
            }
            Err(anyhow!("Invalid generated function."))
        }
    }
991
    /// Generate `barrier(barrier_num, total_barriers, thread_count)`: a
    /// counting spin barrier over the shared `thread_counter` global.
    ///
    /// Each arriving thread atomically increments the counter, then spins
    /// until the counter reaches `barrier_num * thread_count` (i.e. all
    /// threads have passed barrier `barrier_num`). `barrier_num ==
    /// total_barriers` is a fast-path that skips the wait entirely.
    ///
    /// The function is marked `noinline` so Enzyme's custom call handlers
    /// (`fwd_handler`/`rev_handler`) can intercept the call; the attribute is
    /// stripped again in `post_autodiff_optimisation`.
    ///
    /// NOTE(review): pure busy-wait with `Monotonic` atomics and no
    /// pause/yield hint — confirm this is acceptable for the expected thread
    /// counts.
    fn compile_barrier(&mut self) -> Result<FunctionValue<'ctx>> {
        self.clear();
        let function = self.add_function(
            "barrier",
            &["barrier_num", "total_barriers", "thread_count"],
            &[
                self.int_type.into(),
                self.int_type.into(),
                self.int_type.into(),
            ],
            None,
            false,
        );
        let nolinline_kind_id = Attribute::get_named_enum_kind_id("noinline");
        let noinline = self.context.create_enum_attribute(nolinline_kind_id, 0);
        function.add_attribute(AttributeLoc::Function, noinline);

        let _entry_block = self.start_function(function, None);
        let increment_block = self.context.append_basic_block(function, "increment");
        let wait_loop_block = self.context.append_basic_block(function, "wait_loop");
        let barrier_done_block = self.context.append_basic_block(function, "barrier_done");

        let thread_counter = self.globals.thread_counter.unwrap().as_pointer_value();
        let barrier_num = function.get_nth_param(0).unwrap().into_int_value();
        let total_barriers = function.get_nth_param(1).unwrap().into_int_value();
        let thread_count = function.get_nth_param(2).unwrap().into_int_value();

        // Fast path: the final barrier number skips the synchronization.
        let nbarrier_equals_total_barriers = self
            .builder
            .build_int_compare(
                IntPredicate::EQ,
                barrier_num,
                total_barriers,
                "nbarrier_equals_total_barriers",
            )
            .unwrap();
        self.builder.build_conditional_branch(
            nbarrier_equals_total_barriers,
            barrier_done_block,
            increment_block,
        )?;
        self.builder.position_at_end(increment_block);

        // Wait target: by barrier `barrier_num`, the counter must have been
        // incremented `barrier_num * thread_count` times in total.
        let barrier_num_times_thread_count = self
            .builder
            .build_int_mul(barrier_num, thread_count, "barrier_num_times_thread_count")
            .unwrap();

        // Announce arrival: atomic fetch-add of 1.
        let i32_type = self.context.i32_type();
        let one = i32_type.const_int(1, false);
        self.builder.build_atomicrmw(
            AtomicRMWBinOp::Add,
            thread_counter,
            one,
            AtomicOrdering::Monotonic,
        )?;

        self.builder.build_unconditional_branch(wait_loop_block)?;
        self.builder.position_at_end(wait_loop_block);

        // Spin: atomic (monotonic) load of the counter each iteration.
        let current_value = self
            .builder
            .build_load(i32_type, thread_counter, "current_value")?
            .into_int_value();

        current_value
            .as_instruction_value()
            .unwrap()
            .set_atomic_ordering(AtomicOrdering::Monotonic)
            .map_err(|e| anyhow!("Error setting atomic ordering: {:?}", e))?;

        let all_threads_done = self.builder.build_int_compare(
            IntPredicate::UGE,
            current_value,
            barrier_num_times_thread_count,
            "all_threads_done",
        )?;

        self.builder.build_conditional_branch(
            all_threads_done,
            barrier_done_block,
            wait_loop_block,
        )?;
        self.builder.position_at_end(barrier_done_block);

        self.builder.build_return(None)?;

        if function.verify(true) {
            Ok(function)
        } else {
            function.print_to_stderr();
            unsafe {
                function.delete();
            }
            Err(anyhow!("Invalid generated function."))
        }
    }
1092
    /// Build the `barrier_grad(barrier_num, total_barriers, thread_count)`
    /// function used to synchronise threads during the reverse (gradient) pass.
    ///
    /// Each thread atomically increments the shared `thread_counter` global and
    /// then spins until the counter reaches
    /// `(2 * total_barriers - barrier_num) * thread_count`.
    /// NOTE(review): presumably the gradient pass replays barriers in reverse
    /// order, hence the `2 * total_barriers - barrier_num` target — confirm
    /// against the forward `barrier` function and its call sites.
    fn compile_barrier_grad(&mut self) -> Result<FunctionValue<'ctx>> {
        self.clear();
        let function = self.add_function(
            "barrier_grad",
            &["barrier_num", "total_barriers", "thread_count"],
            &[
                self.int_type.into(),
                self.int_type.into(),
                self.int_type.into(),
            ],
            None,
            false,
        );

        let _entry_block = self.start_function(function, None);
        let wait_loop_block = self.context.append_basic_block(function, "wait_loop");
        let barrier_done_block = self.context.append_basic_block(function, "barrier_done");

        let thread_counter = self.globals.thread_counter.unwrap().as_pointer_value();
        let barrier_num = function.get_nth_param(0).unwrap().into_int_value();
        let total_barriers = function.get_nth_param(1).unwrap().into_int_value();
        let thread_count = function.get_nth_param(2).unwrap().into_int_value();

        // Compute the counter value this barrier waits for:
        // (2 * total_barriers - barrier_num) * thread_count.
        let twice_total_barriers = self
            .builder
            .build_int_mul(
                total_barriers,
                self.int_type.const_int(2, false),
                "twice_total_barriers",
            )
            .unwrap();
        let twice_total_barriers_minus_barrier_num = self
            .builder
            .build_int_sub(
                twice_total_barriers,
                barrier_num,
                "twice_total_barriers_minus_barrier_num",
            )
            .unwrap();
        let twice_total_barriers_minus_barrier_num_times_thread_count = self
            .builder
            .build_int_mul(
                twice_total_barriers_minus_barrier_num,
                thread_count,
                "twice_total_barriers_minus_barrier_num_times_thread_count",
            )
            .unwrap();

        // Atomically announce this thread's arrival at the barrier.
        // NOTE(review): the counter is accessed as i32 here while the
        // parameters use `self.int_type` — assumes they agree; confirm.
        let i32_type = self.context.i32_type();
        let one = i32_type.const_int(1, false);
        self.builder.build_atomicrmw(
            AtomicRMWBinOp::Add,
            thread_counter,
            one,
            AtomicOrdering::Monotonic,
        )?;

        self.builder.build_unconditional_branch(wait_loop_block)?;
        self.builder.position_at_end(wait_loop_block);

        // Spin: re-load the counter (as an atomic monotonic load) each
        // iteration until all threads have arrived.
        let current_value = self
            .builder
            .build_load(i32_type, thread_counter, "current_value")?
            .into_int_value();
        current_value
            .as_instruction_value()
            .unwrap()
            .set_atomic_ordering(AtomicOrdering::Monotonic)
            .map_err(|e| anyhow!("Error setting atomic ordering: {:?}", e))?;

        let all_threads_done = self.builder.build_int_compare(
            IntPredicate::UGE,
            current_value,
            twice_total_barriers_minus_barrier_num_times_thread_count,
            "all_threads_done",
        )?;

        self.builder.build_conditional_branch(
            all_threads_done,
            barrier_done_block,
            wait_loop_block,
        )?;
        self.builder.position_at_end(barrier_done_block);

        self.builder.build_return(None)?;

        // Verify the generated IR; on failure dump it for debugging and
        // remove the broken function from the module.
        if function.verify(true) {
            Ok(function)
        } else {
            function.print_to_stderr();
            unsafe {
                function.delete();
            }
            Err(anyhow!("Invalid generated function."))
        }
    }
1191
1192 fn jit_compile_call_barrier(&mut self, nbarrier: u64, total_barriers: u64) {
1193 if !self.threaded {
1194 return;
1195 }
1196 let thread_dim = self.get_param("thread_dim");
1197 let thread_dim = self
1198 .builder
1199 .build_load(self.int_type, *thread_dim, "thread_dim")
1200 .unwrap()
1201 .into_int_value();
1202 let nbarrier = self.int_type.const_int(nbarrier + 1, false);
1203 let total_barriers = self.int_type.const_int(total_barriers, false);
1204 let barrier = self.get_function("barrier").unwrap();
1205 self.builder
1206 .build_call(
1207 barrier,
1208 &[
1209 BasicMetadataValueEnum::IntValue(nbarrier),
1210 BasicMetadataValueEnum::IntValue(total_barriers),
1211 BasicMetadataValueEnum::IntValue(thread_dim),
1212 ],
1213 "barrier",
1214 )
1215 .unwrap();
1216 }
1217
1218 fn jit_threading_limits(
1219 &mut self,
1220 size: IntValue<'ctx>,
1221 ) -> Result<(
1222 IntValue<'ctx>,
1223 IntValue<'ctx>,
1224 BasicBlock<'ctx>,
1225 BasicBlock<'ctx>,
1226 )> {
1227 let one = self.int_type.const_int(1, false);
1228 let thread_id = self.get_param("thread_id");
1229 let thread_id = self
1230 .builder
1231 .build_load(self.int_type, *thread_id, "thread_id")
1232 .unwrap()
1233 .into_int_value();
1234 let thread_dim = self.get_param("thread_dim");
1235 let thread_dim = self
1236 .builder
1237 .build_load(self.int_type, *thread_dim, "thread_dim")
1238 .unwrap()
1239 .into_int_value();
1240
1241 let i_times_size = self
1243 .builder
1244 .build_int_mul(thread_id, size, "i_times_size")?;
1245 let start = self
1246 .builder
1247 .build_int_unsigned_div(i_times_size, thread_dim, "start")?;
1248
1249 let i_plus_one = self.builder.build_int_add(thread_id, one, "i_plus_one")?;
1251 let i_plus_one_times_size =
1252 self.builder
1253 .build_int_mul(i_plus_one, size, "i_plus_one_times_size")?;
1254 let end = self
1255 .builder
1256 .build_int_unsigned_div(i_plus_one_times_size, thread_dim, "end")?;
1257
1258 let test_done = self.builder.get_insert_block().unwrap();
1259 let next_block = self
1260 .context
1261 .append_basic_block(self.fn_value_opt.unwrap(), "threading_block");
1262 self.builder.position_at_end(next_block);
1263
1264 Ok((start, end, test_done, next_block))
1265 }
1266
1267 fn jit_end_threading(
1268 &mut self,
1269 start: IntValue<'ctx>,
1270 end: IntValue<'ctx>,
1271 test_done: BasicBlock<'ctx>,
1272 next: BasicBlock<'ctx>,
1273 ) -> Result<()> {
1274 let exit = self
1275 .context
1276 .append_basic_block(self.fn_value_opt.unwrap(), "exit");
1277 self.builder.build_unconditional_branch(exit)?;
1278 self.builder.position_at_end(test_done);
1279 let done = self
1281 .builder
1282 .build_int_compare(IntPredicate::EQ, start, end, "done")?;
1283 self.builder.build_conditional_branch(done, exit, next)?;
1284 self.builder.position_at_end(exit);
1285 Ok(())
1286 }
1287
    /// Write the compiled LLVM module to `path` as bitcode.
    ///
    /// NOTE(review): the underlying call's success indicator is discarded —
    /// consider surfacing write failures to the caller.
    pub fn write_bitcode_to_path(&self, path: &std::path::Path) {
        self.module.write_bitcode_to_path(path);
    }
1291
1292 fn insert_constants(&mut self, model: &DiscreteModel) {
1293 if let Some(constants) = self.globals.constants.as_ref() {
1294 self.insert_param("constants", constants.as_pointer_value());
1295 for tensor in model.constant_defns() {
1296 self.insert_tensor(tensor, true);
1297 }
1298 }
1299 }
1300
    /// Register pointers for all model tensors: constants first, then the
    /// input tensor (if any) and the input-, time- and state-dependent
    /// definitions, all addressed from the `data` buffer.
    fn insert_data(&mut self, model: &DiscreteModel) {
        self.insert_constants(model);

        if let Some(input) = model.input() {
            self.insert_tensor(input, false);
        }
        for tensor in model.input_dep_defns() {
            self.insert_tensor(tensor, false);
        }
        for tensor in model.time_dep_defns() {
            self.insert_tensor(tensor, false);
        }
        for tensor in model.state_dep_defns() {
            self.insert_tensor(tensor, false);
        }
    }
1317
    /// Return a pointer type in the default address space. The pointee type is
    /// ignored (opaque pointers); the `_ty` parameter is kept so call sites
    /// document the intended pointee.
    fn pointer_type(context: &'ctx Context, _ty: BasicTypeEnum<'ctx>) -> PointerType<'ctx> {
        context.ptr_type(AddressSpace::default())
    }
1321
    /// Return a function-pointer type in the default address space. As with
    /// `pointer_type`, the signature argument is unused (opaque pointers) and
    /// kept only for documentation at call sites.
    fn fn_pointer_type(context: &'ctx Context, _ty: FunctionType<'ctx>) -> PointerType<'ctx> {
        context.ptr_type(AddressSpace::default())
    }
1325
    /// If the `indices` global exists, register a pointer to its first element
    /// under the name "indices" in the variables map.
    fn insert_indices(&mut self) {
        if let Some(indices) = self.globals.indices.as_ref() {
            let i32_type = self.context.i32_type();
            let zero = i32_type.const_int(0, false);
            // Constant GEP with index 0: a pointer to the start of the global.
            // SAFETY: index 0 is always in bounds for the global's storage.
            let ptr = unsafe {
                indices
                    .as_pointer_value()
                    .const_in_bounds_gep(i32_type, &[zero])
            };
            self.variables.insert("indices".to_owned(), ptr);
        }
    }
1338
    /// Register a named pointer (function parameter or global) in the
    /// variables map so later codegen can look it up by name.
    fn insert_param(&mut self, name: &str, value: PointerValue<'ctx>) {
        self.variables.insert(name.to_owned(), value);
    }
1342
1343 fn build_gep<T: BasicType<'ctx>>(
1344 &self,
1345 ty: T,
1346 ptr: PointerValue<'ctx>,
1347 ordered_indexes: &[IntValue<'ctx>],
1348 name: &str,
1349 ) -> Result<PointerValue<'ctx>> {
1350 unsafe {
1351 self.builder
1352 .build_gep(ty, ptr, ordered_indexes, name)
1353 .map_err(|e| e.into())
1354 }
1355 }
1356
1357 fn build_load<T: BasicType<'ctx>>(
1358 &self,
1359 ty: T,
1360 ptr: PointerValue<'ctx>,
1361 name: &str,
1362 ) -> Result<BasicValueEnum<'ctx>> {
1363 self.builder.build_load(ty, ptr, name).map_err(|e| e.into())
1364 }
1365
1366 fn get_ptr_to_index<T: BasicType<'ctx>>(
1367 builder: &Builder<'ctx>,
1368 ty: T,
1369 ptr: &PointerValue<'ctx>,
1370 index: IntValue<'ctx>,
1371 name: &str,
1372 ) -> PointerValue<'ctx> {
1373 unsafe {
1374 builder
1375 .build_in_bounds_gep(ty, *ptr, &[index], name)
1376 .unwrap()
1377 }
1378 }
1379
1380 fn insert_state(&mut self, u: &Tensor) {
1381 let mut data_index = 0;
1382 for blk in u.elmts() {
1383 if let Some(name) = blk.name() {
1384 let ptr = self.variables.get("u").unwrap();
1385 let i = self
1386 .context
1387 .i32_type()
1388 .const_int(data_index.try_into().unwrap(), false);
1389 let alloca = Self::get_ptr_to_index(
1390 &self.create_entry_block_builder(),
1391 self.real_type,
1392 ptr,
1393 i,
1394 blk.name().unwrap(),
1395 );
1396 self.variables.insert(name.to_owned(), alloca);
1397 }
1398 data_index += blk.nnz();
1399 }
1400 }
1401 fn insert_dot_state(&mut self, dudt: &Tensor) {
1402 let mut data_index = 0;
1403 for blk in dudt.elmts() {
1404 if let Some(name) = blk.name() {
1405 let ptr = self.variables.get("dudt").unwrap();
1406 let i = self
1407 .context
1408 .i32_type()
1409 .const_int(data_index.try_into().unwrap(), false);
1410 let alloca = Self::get_ptr_to_index(
1411 &self.create_entry_block_builder(),
1412 self.real_type,
1413 ptr,
1414 i,
1415 blk.name().unwrap(),
1416 );
1417 self.variables.insert(name.to_owned(), alloca);
1418 }
1419 data_index += blk.nnz();
1420 }
1421 }
1422 fn insert_tensor(&mut self, tensor: &Tensor, is_constant: bool) {
1423 let var_name = if is_constant { "constants" } else { "data" };
1424 let ptr = *self.variables.get(var_name).unwrap();
1425 let mut data_index = self.layout.get_data_index(tensor.name()).unwrap();
1426 let i = self
1427 .context
1428 .i32_type()
1429 .const_int(data_index.try_into().unwrap(), false);
1430 let alloca = Self::get_ptr_to_index(
1431 &self.create_entry_block_builder(),
1432 self.real_type,
1433 &ptr,
1434 i,
1435 tensor.name(),
1436 );
1437 self.variables.insert(tensor.name().to_owned(), alloca);
1438
1439 for blk in tensor.elmts() {
1441 if let Some(name) = blk.name() {
1442 let i = self
1443 .context
1444 .i32_type()
1445 .const_int(data_index.try_into().unwrap(), false);
1446 let alloca = Self::get_ptr_to_index(
1447 &self.create_entry_block_builder(),
1448 self.real_type,
1449 &ptr,
1450 i,
1451 name,
1452 );
1453 self.variables.insert(name.to_owned(), alloca);
1454 }
1455 data_index += blk.nnz();
1457 }
1458 }
    /// Look up a named parameter pointer previously registered via
    /// `insert_param`.
    ///
    /// # Panics
    /// Panics if `name` was never registered.
    fn get_param(&self, name: &str) -> &PointerValue<'ctx> {
        self.variables.get(name).unwrap()
    }
1462
    /// Look up the pointer registered for `tensor` (see `insert_tensor`).
    ///
    /// # Panics
    /// Panics if the tensor was never registered.
    fn get_var(&self, tensor: &Tensor) -> &PointerValue<'ctx> {
        self.variables.get(tensor.name()).unwrap()
    }
1466
1467 fn get_function(&mut self, name: &str) -> Option<FunctionValue<'ctx>> {
1468 if let Some(&func) = self.functions.get(name) {
1470 return Some(func);
1471 }
1472 let current_block = self.builder.get_insert_block().unwrap();
1473 let current_loc = self.builder.get_current_debug_location().unwrap();
1474 let current_function = self.fn_value_opt.unwrap();
1475
1476 let function = match name {
1477 "sin" | "cos" | "tan" | "exp" | "log" | "log10" | "sqrt" | "abs" | "copysign"
1479 | "pow" | "min" | "max" => {
1480 let intrinsic_name = match name {
1481 "min" => "minnum",
1482 "max" => "maxnum",
1483 "abs" => "fabs",
1484 _ => name,
1485 };
1486 let llvm_name =
1487 format!("llvm.{}.{}", intrinsic_name, self.diffsl_real_type.as_str());
1488
1489 let args_types: Vec<BasicTypeEnum> = vec![self.real_type.into()];
1490 self.builder.unset_current_debug_location();
1491
1492 Some(match Intrinsic::find(&llvm_name) {
1494 Some(intrinsic) => intrinsic.get_declaration(&self.module, &args_types)?,
1495 None => {
1496 let args_types_meta: Vec<BasicMetadataTypeEnum> =
1498 vec![self.real_type.into()];
1499 let function_type = self.real_type.fn_type(&args_types_meta, false);
1500 self.module
1501 .add_function(name, function_type, Some(Linkage::External))
1502 }
1503 })
1504 }
1505 "sigmoid" => {
1507 let arg_len = 1;
1508 let ret_type = self.real_type;
1509
1510 let args_types = std::iter::repeat_n(ret_type, arg_len)
1511 .map(|f| f.into())
1512 .collect::<Vec<BasicMetadataTypeEnum>>();
1513
1514 let fn_val = self.add_function(name, &["x"], &args_types, None, true);
1515
1516 for arg in fn_val.get_param_iter() {
1517 arg.into_float_value().set_name("x");
1518 }
1519 let _basic_block = self.start_function(fn_val, None);
1520
1521 let x = fn_val.get_nth_param(0)?.into_float_value();
1522 let one = self.real_type.const_float(1.0);
1523 let negx = self.builder.build_float_neg(x, name).ok()?;
1524 let exp = self.get_function("exp").unwrap();
1525 let exp_negx = self
1526 .builder
1527 .build_call(exp, &[BasicMetadataValueEnum::FloatValue(negx)], name)
1528 .ok()?;
1529 let one_plus_exp_negx = self
1530 .builder
1531 .build_float_add(
1532 exp_negx
1533 .try_as_basic_value()
1534 .unwrap_basic()
1535 .into_float_value(),
1536 one,
1537 name,
1538 )
1539 .ok()?;
1540 let sigmoid = self
1541 .builder
1542 .build_float_div(one, one_plus_exp_negx, name)
1543 .ok()?;
1544 self.builder.build_return(Some(&sigmoid)).ok();
1545 Some(fn_val)
1546 }
1547 "arcsinh" | "arccosh" => {
1548 let arg_len = 1;
1549 let ret_type = self.real_type;
1550
1551 let args_types = std::iter::repeat_n(ret_type, arg_len)
1552 .map(|f| f.into())
1553 .collect::<Vec<BasicMetadataTypeEnum>>();
1554
1555 let fn_val = self.add_function(name, &["x"], &args_types, None, true);
1556
1557 for arg in fn_val.get_param_iter() {
1558 arg.into_float_value().set_name("x");
1559 }
1560
1561 let _basic_block = self.start_function(fn_val, None);
1562 let x = fn_val.get_nth_param(0)?.into_float_value();
1563 let one = match name {
1564 "arccosh" => self.real_type.const_float(-1.0),
1565 "arcsinh" => self.real_type.const_float(1.0),
1566 _ => panic!("unknown function"),
1567 };
1568 let x_squared = self.builder.build_float_mul(x, x, name).ok()?;
1569 let one_plus_x_squared = self.builder.build_float_add(x_squared, one, name).ok()?;
1570 let sqrt = self.get_function("sqrt").unwrap();
1571 let sqrt_one_plus_x_squared = self
1572 .builder
1573 .build_call(
1574 sqrt,
1575 &[BasicMetadataValueEnum::FloatValue(one_plus_x_squared)],
1576 name,
1577 )
1578 .unwrap()
1579 .try_as_basic_value()
1580 .unwrap_basic()
1581 .into_float_value();
1582 let x_plus_sqrt_one_plus_x_squared = self
1583 .builder
1584 .build_float_add(x, sqrt_one_plus_x_squared, name)
1585 .ok()?;
1586 let ln = self.get_function("log").unwrap();
1587 let result = self
1588 .builder
1589 .build_call(
1590 ln,
1591 &[BasicMetadataValueEnum::FloatValue(
1592 x_plus_sqrt_one_plus_x_squared,
1593 )],
1594 name,
1595 )
1596 .unwrap()
1597 .try_as_basic_value()
1598 .unwrap_basic()
1599 .into_float_value();
1600 self.builder.build_return(Some(&result)).ok();
1601 Some(fn_val)
1602 }
1603 "heaviside" => {
1604 let arg_len = 1;
1605 let ret_type = self.real_type;
1606
1607 let args_types = std::iter::repeat_n(ret_type, arg_len)
1608 .map(|f| f.into())
1609 .collect::<Vec<BasicMetadataTypeEnum>>();
1610
1611 let fn_val = self.add_function(name, &["x"], &args_types, None, true);
1612
1613 for arg in fn_val.get_param_iter() {
1614 arg.into_float_value().set_name("x");
1615 }
1616
1617 let _basic_block = self.start_function(fn_val, None);
1618 let x = fn_val.get_nth_param(0)?.into_float_value();
1619 let zero = self.real_type.const_float(0.0);
1620 let one = self.real_type.const_float(1.0);
1621 let result = self
1622 .builder
1623 .build_select(
1624 self.builder
1625 .build_float_compare(FloatPredicate::OGE, x, zero, "x >= 0")
1626 .unwrap(),
1627 one,
1628 zero,
1629 name,
1630 )
1631 .ok()?;
1632 self.builder.build_return(Some(&result)).ok();
1633 Some(fn_val)
1634 }
1635 "tanh" | "sinh" | "cosh" => {
1636 let arg_len = 1;
1637 let ret_type = self.real_type;
1638
1639 let args_types = std::iter::repeat_n(ret_type, arg_len)
1640 .map(|f| f.into())
1641 .collect::<Vec<BasicMetadataTypeEnum>>();
1642
1643 let fn_val = self.add_function(name, &["x"], &args_types, None, true);
1644
1645 for arg in fn_val.get_param_iter() {
1646 arg.into_float_value().set_name("x");
1647 }
1648
1649 let _basic_block = self.start_function(fn_val, None);
1650 let x = fn_val.get_nth_param(0)?.into_float_value();
1651 let negx = self.builder.build_float_neg(x, name).ok()?;
1652 let exp = self.get_function("exp").unwrap();
1653 let exp_negx = self
1654 .builder
1655 .build_call(exp, &[BasicMetadataValueEnum::FloatValue(negx)], name)
1656 .ok()?;
1657 let expx = self
1658 .builder
1659 .build_call(exp, &[BasicMetadataValueEnum::FloatValue(x)], name)
1660 .ok()?;
1661 let expx_minus_exp_negx = self
1662 .builder
1663 .build_float_sub(
1664 expx.try_as_basic_value().unwrap_basic().into_float_value(),
1665 exp_negx
1666 .try_as_basic_value()
1667 .unwrap_basic()
1668 .into_float_value(),
1669 name,
1670 )
1671 .ok()?;
1672 let expx_plus_exp_negx = self
1673 .builder
1674 .build_float_add(
1675 expx.try_as_basic_value().unwrap_basic().into_float_value(),
1676 exp_negx
1677 .try_as_basic_value()
1678 .unwrap_basic()
1679 .into_float_value(),
1680 name,
1681 )
1682 .ok()?;
1683 let result = match name {
1684 "tanh" => self
1685 .builder
1686 .build_float_div(expx_minus_exp_negx, expx_plus_exp_negx, name)
1687 .ok()?,
1688 "sinh" => self
1689 .builder
1690 .build_float_div(expx_minus_exp_negx, self.real_type.const_float(2.0), name)
1691 .ok()?,
1692 "cosh" => self
1693 .builder
1694 .build_float_div(expx_plus_exp_negx, self.real_type.const_float(2.0), name)
1695 .ok()?,
1696 _ => panic!("unknown function"),
1697 };
1698 self.builder.build_return(Some(&result)).ok();
1699 Some(fn_val)
1700 }
1701 _ => None,
1702 }?;
1703
1704 if !function.verify(true) {
1705 function.print_to_stderr();
1706 panic!("Invalid generated function for {}", name);
1707 }
1708
1709 self.builder.position_at_end(current_block);
1710 if self.dibuilder.is_some() {
1711 self.builder.set_current_debug_location(current_loc);
1712 }
1713 self.fn_value_opt = Some(current_function);
1714
1715 self.functions.insert(name.to_owned(), function);
1716 Some(function)
1717 }
1718
    /// The function currently being compiled.
    ///
    /// # Panics
    /// Panics if no function is being compiled (`fn_value_opt` is `None`).
    #[inline]
    fn fn_value(&self) -> FunctionValue<'ctx> {
        self.fn_value_opt.unwrap()
    }
1724
    /// Pointer to the output of the tensor currently being compiled.
    ///
    /// # Panics
    /// Panics if no tensor compilation is in progress (`tensor_ptr_opt` is `None`).
    #[inline]
    fn tensor_ptr(&self) -> PointerValue<'ctx> {
        self.tensor_ptr_opt.unwrap()
    }
1729
1730 fn create_entry_block_builder(&self) -> Builder<'ctx> {
1732 let builder = self.context.create_builder();
1733 let entry = self.fn_value().get_first_basic_block().unwrap();
1734 match entry.get_first_instruction() {
1735 Some(first_instr) => builder.position_before(&first_instr),
1736 None => builder.position_at_end(entry),
1737 }
1738 builder
1739 }
1740
    /// JIT-compile a rank-0 (scalar) tensor: evaluate its single element's
    /// expression and store the result into `res_ptr_opt` (or a fresh
    /// entry-block alloca).
    ///
    /// In threaded mode only thread 0 computes and stores the value; all other
    /// threads branch straight to the exit block.
    fn jit_compile_scalar(
        &mut self,
        a: &Tensor,
        res_ptr_opt: Option<PointerValue<'ctx>>,
    ) -> Result<PointerValue<'ctx>> {
        let res_type = self.real_type;
        let res_ptr = match res_ptr_opt {
            Some(ptr) => ptr,
            None => self
                .create_entry_block_builder()
                .build_alloca(res_type, a.name())?,
        };
        let name = a.name();
        let elmt = a.elmts().first().unwrap();

        // In threaded mode, emit the computation into a separate block so we
        // can gate it on the thread id below.
        let curr_block = self.builder.get_insert_block().unwrap();
        let mut next_block_opt = None;
        if self.threaded {
            let next_block = self.context.append_basic_block(self.fn_value(), "next");
            self.builder.position_at_end(next_block);
            next_block_opt = Some(next_block);
        }

        let zero = self.int_type.const_zero();
        let float_value = self.jit_compile_expr(name, elmt.expr(), &[], elmt, zero)?;
        self.builder.build_store(res_ptr, float_value)?;

        if self.threaded {
            let exit_block = self.context.append_basic_block(self.fn_value(), "exit");
            self.builder.build_unconditional_branch(exit_block)?;
            // Back-patch the original block: only thread 0 enters the
            // computation block, everyone else jumps to exit.
            self.builder.position_at_end(curr_block);

            let thread_id = self.get_param("thread_id");
            let thread_id = self
                .builder
                .build_load(self.int_type, *thread_id, "thread_id")
                .unwrap()
                .into_int_value();
            let is_first_thread = self.builder.build_int_compare(
                IntPredicate::EQ,
                thread_id,
                self.int_type.const_zero(),
                "is_first_thread",
            )?;
            self.builder.build_conditional_branch(
                is_first_thread,
                next_block_opt.unwrap(),
                exit_block,
            )?;

            self.builder.position_at_end(exit_block);
        }

        Ok(res_ptr)
    }
1798
1799 fn jit_compile_tensor(
1800 &mut self,
1801 a: &Tensor,
1802 res_ptr_opt: Option<PointerValue<'ctx>>,
1803 code: Option<&str>,
1804 ) -> Result<PointerValue<'ctx>> {
1805 if a.rank() == 0 {
1807 return self.jit_compile_scalar(a, res_ptr_opt);
1808 }
1809
1810 let res_type = self.real_type;
1811 let res_ptr = match res_ptr_opt {
1812 Some(ptr) => ptr,
1813 None => self
1814 .create_entry_block_builder()
1815 .build_alloca(res_type, a.name())?,
1816 };
1817
1818 self.tensor_ptr_opt = Some(res_ptr);
1820
1821 for (i, blk) in a.elmts().iter().enumerate() {
1822 let default = format!("{}-{}", a.name(), i);
1823 let name = blk.name().unwrap_or(default.as_str());
1824 self.jit_compile_block(name, a, blk, code)?;
1825 }
1826 Ok(res_ptr)
1827 }
1828
    /// JIT-compile one tensor block: set up a debug location for it (when
    /// debug info is enabled) and dispatch on the expression layout to the
    /// dense, diagonal or sparse code paths.
    fn jit_compile_block(
        &mut self,
        name: &str,
        tensor: &Tensor,
        elmt: &TensorBlock,
        code: Option<&str>,
    ) -> Result<()> {
        let translation = Translation::new(
            elmt.expr_layout(),
            elmt.layout(),
            elmt.start(),
            tensor.layout_ptr(),
        );

        if let Some(dibuilder) = &self.dibuilder {
            // Map the expression's span back to a line/column in the original
            // source text; fall back to (0, 0) when either is unavailable.
            let (line_no, column_no) = match (elmt.expr().span, code) {
                (Some(s), Some(c)) => {
                    let s = Span::new(c, s.pos_start, s.pos_end).unwrap();
                    s.start_pos().line_col()
                }
                _ => (0, 0),
            };
            let scope = self
                .fn_value_opt
                .unwrap()
                .get_subprogram()
                .unwrap()
                .as_debug_info_scope();
            let loc = dibuilder.create_debug_location(
                self.context,
                line_no.try_into().unwrap(),
                column_no.try_into().unwrap(),
                scope,
                None,
            );
            self.builder.set_current_debug_location(loc);
        }

        // Dispatch on the expression layout. Sparse contractions get their
        // own path; all other sparse layouts share the generic sparse path.
        if elmt.expr_layout().is_dense() {
            self.jit_compile_dense_block(name, elmt, &translation)
        } else if elmt.expr_layout().is_diagonal() {
            self.jit_compile_diagonal_block(name, elmt, &translation)
        } else if elmt.expr_layout().is_sparse() {
            match translation.source {
                TranslationFrom::SparseContraction { .. } => {
                    self.jit_compile_sparse_contraction_block(name, elmt, &translation)
                }
                _ => self.jit_compile_sparse_block(name, elmt, &translation),
            }
        } else {
            Err(anyhow!(
                "unsupported block layout: {:?}",
                elmt.expr_layout()
            ))
        }
    }
1885
    /// JIT-compile a dense tensor block as a nest of counted loops (one per
    /// expression dimension), built from phi nodes.
    ///
    /// If the translation is a dense contraction, the innermost `contract_by`
    /// dimensions are summed into a stack slot (`contract_sum`) and stored
    /// once per outer iteration; otherwise each element is broadcast/stored
    /// directly. In threaded mode the outermost loop is split across threads
    /// via `jit_threading_limits` / `jit_end_threading`.
    fn jit_compile_dense_block(
        &mut self,
        name: &str,
        elmt: &TensorBlock,
        translation: &Translation,
    ) -> Result<()> {
        let int_type = self.int_type;

        let mut preblock = self.builder.get_insert_block().unwrap();
        let expr_rank = elmt.expr_layout().rank();
        let expr_shape = elmt
            .expr_layout()
            .shape()
            .mapv(|n| int_type.const_int(n.try_into().unwrap(), false));
        let one = int_type.const_int(1, false);

        // Row-major strides of the expression layout:
        // stride[i] = product of shape[i+1..].
        let mut expr_strides = vec![1; expr_rank];
        if expr_rank > 0 {
            for i in (0..expr_rank - 1).rev() {
                expr_strides[i] = expr_strides[i + 1] * elmt.expr_layout().shape()[i + 1];
            }
        }
        let expr_strides = expr_strides
            .iter()
            .map(|&s| int_type.const_int(s.try_into().unwrap(), false))
            .collect::<Vec<IntValue>>();

        // One phi node (loop index) and one header block per dimension.
        let mut indices = Vec::new();
        let mut blocks = Vec::new();

        // For dense contractions, allocate the running-sum slot and compute
        // the strides of the contracted (result) index.
        // NOTE(review): contract_strides is built from the *expression*
        // shape's tail (`shape()[i + 1]`), same as expr_strides — confirm
        // this is the intended layout for the contraction result.
        let (contract_sum, contract_by, contract_strides) =
            if let TranslationFrom::DenseContraction {
                contract_by,
                contract_len: _,
            } = translation.source
            {
                let contract_rank = expr_rank - contract_by;
                let mut contract_strides = vec![1; contract_rank];
                for i in (0..contract_rank - 1).rev() {
                    contract_strides[i] =
                        contract_strides[i + 1] * elmt.expr_layout().shape()[i + 1];
                }
                let contract_strides = contract_strides
                    .iter()
                    .map(|&s| int_type.const_int(s.try_into().unwrap(), false))
                    .collect::<Vec<IntValue>>();
                (
                    Some(self.builder.build_alloca(self.real_type, "contract_sum")?),
                    contract_by,
                    Some(contract_strides),
                )
            } else {
                (None, 0, None)
            };

        // Split the outermost dimension across threads when threading.
        let (thread_start, thread_end, test_done, next) = if self.threaded {
            let (start, end, test_done, next) =
                self.jit_threading_limits(*expr_shape.get(0).unwrap_or(&one))?;
            preblock = next;
            (Some(start), Some(end), Some(test_done), Some(next))
        } else {
            (None, None, None, None)
        };

        // Build the loop headers, outermost first. Each header holds a phi
        // index seeded with its start value from the preceding block.
        for i in 0..expr_rank {
            let block = self.context.append_basic_block(self.fn_value(), name);
            self.builder.build_unconditional_branch(block)?;
            self.builder.position_at_end(block);

            // Outermost loop starts at the thread's chunk start when threaded.
            let start_index = if i == 0 && self.threaded {
                thread_start.unwrap()
            } else {
                self.int_type.const_zero()
            };

            let curr_index = self.builder.build_phi(int_type, format!["i{i}"].as_str())?;
            curr_index.add_incoming(&[(&start_index, preblock)]);

            // Reset the contraction sum at the start of each contracted run.
            if i == expr_rank - contract_by - 1 {
                if let Some(contract_sum) = contract_sum {
                    self.builder
                        .build_store(contract_sum, self.real_type.const_zero())?;
                }
            }

            indices.push(curr_index);
            blocks.push(block);
            preblock = block;
        }

        let indices_int: Vec<IntValue> = indices
            .iter()
            .map(|i| i.as_basic_value().into_int_value())
            .collect();

        // Flattened (row-major) index of the current element, built from the
        // loop indices. NOTE(review): this duplicates the expr_strides
        // computation above with runtime arithmetic — the fold further down
        // recomputes the same value from expr_strides.
        let mut expr_index = *indices_int.last().unwrap_or(&int_type.const_zero());
        let mut stride = 1u64;
        if !indices.is_empty() {
            for i in (0..indices.len() - 1).rev() {
                let iname_i = indices_int[i];
                let shapei: u64 = elmt.expr_layout().shape()[i + 1].try_into().unwrap();
                stride *= shapei;
                let stride_intval = self.context.i32_type().const_int(stride, false);
                let stride_mul_i = self.builder.build_int_mul(stride_intval, iname_i, name)?;
                expr_index = self.builder.build_int_add(expr_index, stride_mul_i, name)?;
            }
        }

        // Evaluate the block's expression at the current indices.
        let float_value =
            self.jit_compile_expr(name, elmt.expr(), indices_int.as_slice(), elmt, expr_index)?;

        if let Some(contract_sum) = contract_sum {
            // Contraction: accumulate into the running sum.
            let contract_sum_value = self
                .build_load(self.real_type, contract_sum, "contract_sum")?
                .into_float_value();
            let new_contract_sum_value = self.builder.build_float_add(
                contract_sum_value,
                float_value,
                "new_contract_sum",
            )?;
            self.builder
                .build_store(contract_sum, new_contract_sum_value)?;
        } else {
            // No contraction: store/broadcast at the flattened index
            // (recomputed here as sum(index[i] * stride[i])).
            let expr_index = indices_int.iter().zip(expr_strides.iter()).fold(
                self.int_type.const_zero(),
                |acc, (i, s)| {
                    let tmp = self.builder.build_int_mul(*i, *s, "expr_index").unwrap();
                    self.builder.build_int_add(acc, tmp, "acc").unwrap()
                },
            );
            self.jit_compile_broadcast_and_store(name, elmt, float_value, expr_index, translation)?;
        }

        let mut postblock = self.builder.get_insert_block().unwrap();

        // Close the loops, innermost first: bump each index, flush the
        // contraction sum when its run completes, and branch back while the
        // index is below its bound.
        for i in (0..expr_rank).rev() {
            let next_index = self.builder.build_int_add(indices_int[i], one, name)?;
            indices[i].add_incoming(&[(&next_index, postblock)]);
            if i == expr_rank - contract_by - 1 {
                if let Some(contract_sum) = contract_sum {
                    let contract_sum_value = self
                        .build_load(self.real_type, contract_sum, "contract_sum")?
                        .into_float_value();
                    let contract_strides = contract_strides.as_ref().unwrap();
                    // Flattened index into the contraction result.
                    let elmt_index = indices_int
                        .iter()
                        .take(contract_strides.len())
                        .zip(contract_strides.iter())
                        .fold(self.int_type.const_zero(), |acc, (i, s)| {
                            let tmp = self.builder.build_int_mul(*i, *s, "elmt_index").unwrap();
                            self.builder.build_int_add(acc, tmp, "acc").unwrap()
                        });
                    self.jit_compile_store(
                        name,
                        elmt,
                        elmt_index,
                        contract_sum_value,
                        translation,
                    )?;
                }
            }

            // Outermost loop ends at the thread's chunk end when threaded.
            let end_index = if i == 0 && self.threaded {
                thread_end.unwrap()
            } else {
                expr_shape[i]
            };

            let loop_while =
                self.builder
                    .build_int_compare(IntPredicate::ULT, next_index, end_index, name)?;
            let block = self.context.append_basic_block(self.fn_value(), name);
            self.builder
                .build_conditional_branch(loop_while, blocks[i], block)?;
            self.builder.position_at_end(block);
            postblock = block;
        }

        if self.threaded {
            self.jit_end_threading(
                thread_start.unwrap(),
                thread_end.unwrap(),
                test_done.unwrap(),
                next.unwrap(),
            )?;
        }
        Ok(())
    }
2082
2083 fn jit_compile_sparse_contraction_block(
2084 &mut self,
2085 name: &str,
2086 elmt: &TensorBlock,
2087 translation: &Translation,
2088 ) -> Result<()> {
2089 match translation.source {
2090 TranslationFrom::SparseContraction { .. } => {}
2091 _ => {
2092 panic!("expected sparse contraction")
2093 }
2094 }
2095 let int_type = self.int_type;
2096
2097 let translation_index = self
2098 .layout
2099 .get_translation_index(elmt.expr_layout(), elmt.layout())
2100 .unwrap();
2101 let translation_index = translation_index + translation.get_from_index_in_data_layout();
2102
2103 let final_contract_index =
2104 int_type.const_int(elmt.layout().nnz().try_into().unwrap(), false);
2105 let (thread_start, thread_end, test_done, next) = if self.threaded {
2106 let (start, end, test_done, next) = self.jit_threading_limits(final_contract_index)?;
2107 (Some(start), Some(end), Some(test_done), Some(next))
2108 } else {
2109 (None, None, None, None)
2110 };
2111
2112 let preblock = self.builder.get_insert_block().unwrap();
2113 let contract_sum_ptr = self.builder.build_alloca(self.real_type, "contract_sum")?;
2114
2115 let block = self.context.append_basic_block(self.fn_value(), name);
2117 self.builder.build_unconditional_branch(block)?;
2118 self.builder.position_at_end(block);
2119
2120 let contract_index = self.builder.build_phi(int_type, "i")?;
2121 let contract_start = if self.threaded {
2122 thread_start.unwrap()
2123 } else {
2124 int_type.const_zero()
2125 };
2126 contract_index.add_incoming(&[(&contract_start, preblock)]);
2127
2128 let start_index = self.builder.build_int_add(
2129 int_type.const_int(translation_index.try_into().unwrap(), false),
2130 self.builder.build_int_mul(
2131 int_type.const_int(2, false),
2132 contract_index.as_basic_value().into_int_value(),
2133 name,
2134 )?,
2135 name,
2136 )?;
2137 let end_index =
2138 self.builder
2139 .build_int_add(start_index, int_type.const_int(1, false), name)?;
2140 let start_ptr = self.build_gep(
2141 self.int_type,
2142 *self.get_param("indices"),
2143 &[start_index],
2144 "start_index_ptr",
2145 )?;
2146 let start_contract = self
2147 .build_load(self.int_type, start_ptr, "start")?
2148 .into_int_value();
2149 let end_ptr = self.build_gep(
2150 self.int_type,
2151 *self.get_param("indices"),
2152 &[end_index],
2153 "end_index_ptr",
2154 )?;
2155 let end_contract = self
2156 .build_load(self.int_type, end_ptr, "end")?
2157 .into_int_value();
2158
2159 self.builder
2161 .build_store(contract_sum_ptr, self.real_type.const_float(0.0))?;
2162
2163 let start_contract_block = self
2165 .context
2166 .append_basic_block(self.fn_value(), format!("{name}_contract").as_str());
2167 self.builder
2168 .build_unconditional_branch(start_contract_block)?;
2169 self.builder.position_at_end(start_contract_block);
2170
2171 let expr_index_phi = self.builder.build_phi(int_type, "j")?;
2172 expr_index_phi.add_incoming(&[(&start_contract, block)]);
2173
2174 let expr_index = expr_index_phi.as_basic_value().into_int_value();
2175 let indices_int = self.expr_indices_from_elmt_index(expr_index, elmt, name)?;
2176
2177 let float_value =
2179 self.jit_compile_expr(name, elmt.expr(), indices_int.as_slice(), elmt, expr_index)?;
2180 let contract_sum_value = self
2181 .build_load(self.real_type, contract_sum_ptr, "contract_sum")?
2182 .into_float_value();
2183 let new_contract_sum_value =
2184 self.builder
2185 .build_float_add(contract_sum_value, float_value, "new_contract_sum")?;
2186 self.builder
2187 .build_store(contract_sum_ptr, new_contract_sum_value)?;
2188
2189 let end_contract_block = self.builder.get_insert_block().unwrap();
2190
2191 let next_elmt_index =
2193 self.builder
2194 .build_int_add(expr_index, int_type.const_int(1, false), name)?;
2195 expr_index_phi.add_incoming(&[(&next_elmt_index, end_contract_block)]);
2196
2197 let loop_while = self.builder.build_int_compare(
2199 IntPredicate::ULT,
2200 next_elmt_index,
2201 end_contract,
2202 name,
2203 )?;
2204 let post_contract_block = self.context.append_basic_block(self.fn_value(), name);
2205 self.builder.build_conditional_branch(
2206 loop_while,
2207 start_contract_block,
2208 post_contract_block,
2209 )?;
2210 self.builder.position_at_end(post_contract_block);
2211
2212 self.jit_compile_store(
2214 name,
2215 elmt,
2216 contract_index.as_basic_value().into_int_value(),
2217 new_contract_sum_value,
2218 translation,
2219 )?;
2220
2221 let next_contract_index = self.builder.build_int_add(
2223 contract_index.as_basic_value().into_int_value(),
2224 int_type.const_int(1, false),
2225 name,
2226 )?;
2227 contract_index.add_incoming(&[(&next_contract_index, post_contract_block)]);
2228
2229 let loop_while = self.builder.build_int_compare(
2231 IntPredicate::ULT,
2232 next_contract_index,
2233 thread_end.unwrap_or(final_contract_index),
2234 name,
2235 )?;
2236 let post_block = self.context.append_basic_block(self.fn_value(), name);
2237 self.builder
2238 .build_conditional_branch(loop_while, block, post_block)?;
2239 self.builder.position_at_end(post_block);
2240
2241 if self.threaded {
2242 self.jit_end_threading(
2243 thread_start.unwrap(),
2244 thread_end.unwrap(),
2245 test_done.unwrap(),
2246 next.unwrap(),
2247 )?;
2248 }
2249
2250 Ok(())
2251 }
2252
2253 fn expr_indices_from_elmt_index(
2254 &mut self,
2255 elmt_index: IntValue<'ctx>,
2256 elmt: &TensorBlock,
2257 name: &str,
2258 ) -> Result<Vec<IntValue<'ctx>>, anyhow::Error> {
2259 let layout_index = self.layout.get_layout_index(elmt.expr_layout()).unwrap();
2260 let int_type = self.int_type;
2261 let elmt_index_mult_rank = self.builder.build_int_mul(
2263 elmt_index,
2264 int_type.const_int(elmt.expr_layout().rank().try_into().unwrap(), false),
2265 name,
2266 )?;
2267 (0..elmt.expr_layout().rank())
2268 .map(|i| {
2269 let layout_index_plus_offset =
2270 int_type.const_int((layout_index + i).try_into().unwrap(), false);
2271 let curr_index = self.builder.build_int_add(
2272 elmt_index_mult_rank,
2273 layout_index_plus_offset,
2274 name,
2275 )?;
2276 let ptr = Self::get_ptr_to_index(
2277 &self.builder,
2278 self.int_type,
2279 self.get_param("indices"),
2280 curr_index,
2281 name,
2282 );
2283 Ok(self.build_load(self.int_type, ptr, name)?.into_int_value())
2284 })
2285 .collect::<Result<Vec<_>, anyhow::Error>>()
2286 }
2287
    /// Compile a loop over the non-zero elements of a sparse tensor block:
    /// for each nnz element, compute its multi-dimensional indices, evaluate
    /// the block's expression, and store (possibly broadcast) the result.
    ///
    /// When `self.threaded`, the loop bounds come from `jit_threading_limits`
    /// so each thread processes a contiguous slice of the nnz range, and the
    /// per-thread epilogue is emitted via `jit_end_threading`.
    fn jit_compile_sparse_block(
        &mut self,
        name: &str,
        elmt: &TensorBlock,
        translation: &Translation,
    ) -> Result<()> {
        let int_type = self.int_type;

        let start_index = int_type.const_int(0, false);
        let end_index = int_type.const_int(elmt.expr_layout().nnz().try_into().unwrap(), false);

        // Threaded mode replaces the [0, nnz) bounds with this thread's slice.
        let (thread_start, thread_end, test_done, next) = if self.threaded {
            let (start, end, test_done, next) = self.jit_threading_limits(end_index)?;
            (Some(start), Some(end), Some(test_done), Some(next))
        } else {
            (None, None, None, None)
        };

        // Remember the predecessor block for the loop phi, then enter the loop.
        let preblock = self.builder.get_insert_block().unwrap();
        let loop_block = self.context.append_basic_block(self.fn_value(), name);
        self.builder.build_unconditional_branch(loop_block)?;
        self.builder.position_at_end(loop_block);

        // Loop counter phi: starts at the (thread) start index.
        let curr_index = self.builder.build_phi(int_type, "i")?;
        curr_index.add_incoming(&[(&thread_start.unwrap_or(start_index), preblock)]);

        let elmt_index = curr_index.as_basic_value().into_int_value();
        let indices_int = self.expr_indices_from_elmt_index(elmt_index, elmt, name)?;

        let float_value =
            self.jit_compile_expr(name, elmt.expr(), indices_int.as_slice(), elmt, elmt_index)?;

        self.jit_compile_broadcast_and_store(name, elmt, float_value, elmt_index, translation)?;

        // The store may have emitted extra blocks; the latch is wherever the
        // builder ended up, not necessarily `loop_block`.
        let end_loop_block = self.builder.get_insert_block().unwrap();

        let one = int_type.const_int(1, false);
        let next_index = self.builder.build_int_add(elmt_index, one, name)?;
        curr_index.add_incoming(&[(&next_index, end_loop_block)]);

        // Back-edge condition: keep looping while next_index < end.
        let loop_while = self.builder.build_int_compare(
            IntPredicate::ULT,
            next_index,
            thread_end.unwrap_or(end_index),
            name,
        )?;
        let post_block = self.context.append_basic_block(self.fn_value(), name);
        self.builder
            .build_conditional_branch(loop_while, loop_block, post_block)?;
        self.builder.position_at_end(post_block);

        if self.threaded {
            self.jit_end_threading(
                thread_start.unwrap(),
                thread_end.unwrap(),
                test_done.unwrap(),
                next.unwrap(),
            )?;
        }

        Ok(())
    }
2357
    /// Compile a loop over a diagonal tensor block. Identical in structure to
    /// `jit_compile_sparse_block`, except that on a diagonal every dimension's
    /// index equals the flat element index, so no index-table lookups are
    /// needed.
    fn jit_compile_diagonal_block(
        &mut self,
        name: &str,
        elmt: &TensorBlock,
        translation: &Translation,
    ) -> Result<()> {
        let int_type = self.int_type;

        let start_index = int_type.const_int(0, false);
        let end_index = int_type.const_int(elmt.expr_layout().nnz().try_into().unwrap(), false);

        // Threaded mode replaces the [0, nnz) bounds with this thread's slice.
        let (thread_start, thread_end, test_done, next) = if self.threaded {
            let (start, end, test_done, next) = self.jit_threading_limits(end_index)?;
            (Some(start), Some(end), Some(test_done), Some(next))
        } else {
            (None, None, None, None)
        };

        // Remember the predecessor block for the loop phi, then enter the loop.
        let preblock = self.builder.get_insert_block().unwrap();
        let start_loop_block = self.context.append_basic_block(self.fn_value(), name);
        self.builder.build_unconditional_branch(start_loop_block)?;
        self.builder.position_at_end(start_loop_block);

        // Loop counter phi: starts at the (thread) start index.
        let curr_index = self.builder.build_phi(int_type, "i")?;
        curr_index.add_incoming(&[(&thread_start.unwrap_or(start_index), preblock)]);

        // On the diagonal, index i maps to (i, i, ..., i) across all dims.
        let elmt_index = curr_index.as_basic_value().into_int_value();
        let indices_int: Vec<IntValue> =
            (0..elmt.expr_layout().rank()).map(|_| elmt_index).collect();

        let float_value =
            self.jit_compile_expr(name, elmt.expr(), indices_int.as_slice(), elmt, elmt_index)?;

        self.jit_compile_broadcast_and_store(name, elmt, float_value, elmt_index, translation)?;

        // The store may have emitted extra blocks; take the actual latch block.
        let end_loop_block = self.builder.get_insert_block().unwrap();

        let one = int_type.const_int(1, false);
        let next_index = self.builder.build_int_add(elmt_index, one, name)?;
        curr_index.add_incoming(&[(&next_index, end_loop_block)]);

        // Back-edge condition: keep looping while next_index < end.
        let loop_while = self.builder.build_int_compare(
            IntPredicate::ULT,
            next_index,
            thread_end.unwrap_or(end_index),
            name,
        )?;
        let post_block = self.context.append_basic_block(self.fn_value(), name);
        self.builder
            .build_conditional_branch(loop_while, start_loop_block, post_block)?;
        self.builder.position_at_end(post_block);

        if self.threaded {
            self.jit_end_threading(
                thread_start.unwrap(),
                thread_end.unwrap(),
                test_done.unwrap(),
                next.unwrap(),
            )?;
        }

        Ok(())
    }
2428
    /// Store a computed element value, replicating it when the translation
    /// says the source is broadcast: one source element fans out to
    /// `broadcast_len` consecutive destination slots. Element-wise and
    /// diagonal-contraction sources store exactly once.
    ///
    /// Returns an error for translation kinds this function cannot handle
    /// (e.g. contractions, which are compiled by a dedicated loop).
    fn jit_compile_broadcast_and_store(
        &mut self,
        name: &str,
        elmt: &TensorBlock,
        float_value: FloatValue<'ctx>,
        expr_index: IntValue<'ctx>,
        translation: &Translation,
    ) -> Result<()> {
        let int_type = self.int_type;
        let one = int_type.const_int(1, false);
        let zero = int_type.const_int(0, false);
        let pre_block = self.builder.get_insert_block().unwrap();
        match translation.source {
            TranslationFrom::Broadcast {
                broadcast_by: _,
                broadcast_len,
            } => {
                let bcast_start_index = zero;
                let bcast_end_index = int_type.const_int(broadcast_len.try_into().unwrap(), false);

                // Inner loop over the broadcast copies of this element.
                let bcast_block = self.context.append_basic_block(self.fn_value(), name);
                self.builder.build_unconditional_branch(bcast_block)?;
                self.builder.position_at_end(bcast_block);
                let bcast_index = self.builder.build_phi(int_type, "broadcast_index")?;
                bcast_index.add_incoming(&[(&bcast_start_index, pre_block)]);

                // Destination slot = expr_index * broadcast_len + copy number.
                let store_index = self.builder.build_int_add(
                    self.builder
                        .build_int_mul(expr_index, bcast_end_index, "store_index")?,
                    bcast_index.as_basic_value().into_int_value(),
                    "bcast_store_index",
                )?;
                self.jit_compile_store(name, elmt, store_index, float_value, translation)?;

                let bcast_next_index = self.builder.build_int_add(
                    bcast_index.as_basic_value().into_int_value(),
                    one,
                    name,
                )?;
                bcast_index.add_incoming(&[(&bcast_next_index, bcast_block)]);

                // Loop while the next copy number is still < broadcast_len.
                let bcast_cond = self.builder.build_int_compare(
                    IntPredicate::ULT,
                    bcast_next_index,
                    bcast_end_index,
                    "broadcast_cond",
                )?;
                let post_bcast_block = self.context.append_basic_block(self.fn_value(), name);
                self.builder
                    .build_conditional_branch(bcast_cond, bcast_block, post_bcast_block)?;
                self.builder.position_at_end(post_bcast_block);
                Ok(())
            }
            TranslationFrom::ElementWise | TranslationFrom::DiagonalContraction { .. } => {
                // One source element maps to exactly one destination slot.
                self.jit_compile_store(name, elmt, expr_index, float_value, translation)?;
                Ok(())
            }
            _ => Err(anyhow!("Invalid translation")),
        }
    }
2493
    /// Store `float_value` into the current result tensor at the slot named by
    /// `store_index`, translated into the destination's layout.
    ///
    /// Contiguous targets add a constant start offset; sparse targets look the
    /// destination slot up in a precomputed translation table held in the
    /// "indices" parameter array.
    fn jit_compile_store(
        &mut self,
        name: &str,
        elmt: &TensorBlock,
        store_index: IntValue<'ctx>,
        float_value: FloatValue<'ctx>,
        translation: &Translation,
    ) -> Result<()> {
        let int_type = self.int_type;
        let res_index = match &translation.target {
            TranslationTo::Contiguous { start, end: _ } => {
                // Destination is a dense range: just shift by its start.
                let start_const = int_type.const_int((*start).try_into().unwrap(), false);
                self.builder.build_int_add(start_const, store_index, name)?
            }
            TranslationTo::Sparse { indices: _ } => {
                // Destination is sparse: map store_index through the
                // translation table located at translate_store_index.
                let translate_index = self
                    .layout
                    .get_translation_index(elmt.expr_layout(), elmt.layout())
                    .unwrap();
                let translate_store_index =
                    translate_index + translation.get_to_index_in_data_layout();
                let translate_store_index =
                    int_type.const_int(translate_store_index.try_into().unwrap(), false);
                let elmt_index_strided = store_index;
                let curr_index =
                    self.builder
                        .build_int_add(elmt_index_strided, translate_store_index, name)?;
                let ptr = Self::get_ptr_to_index(
                    &self.builder,
                    self.int_type,
                    self.get_param("indices"),
                    curr_index,
                    name,
                );
                self.build_load(self.int_type, ptr, name)?.into_int_value()
            }
        };

        // Write the value into the result tensor at the translated index.
        let resi_ptr = Self::get_ptr_to_index(
            &self.builder,
            self.real_type,
            &self.tensor_ptr(),
            res_index,
            name,
        );
        self.builder.build_store(resi_ptr, float_value)?;
        Ok(())
    }
2543
2544 fn jit_compile_expr(
2545 &mut self,
2546 name: &str,
2547 expr: &Ast,
2548 index: &[IntValue<'ctx>],
2549 elmt: &TensorBlock,
2550 expr_index: IntValue<'ctx>,
2551 ) -> Result<FloatValue<'ctx>> {
2552 let name = elmt.name().unwrap_or(name);
2553 match &expr.kind {
2554 AstKind::Binop(binop) => {
2555 let lhs =
2556 self.jit_compile_expr(name, binop.left.as_ref(), index, elmt, expr_index)?;
2557 let rhs =
2558 self.jit_compile_expr(name, binop.right.as_ref(), index, elmt, expr_index)?;
2559 match binop.op {
2560 '*' => Ok(self.builder.build_float_mul(lhs, rhs, name)?),
2561 '/' => Ok(self.builder.build_float_div(lhs, rhs, name)?),
2562 '-' => Ok(self.builder.build_float_sub(lhs, rhs, name)?),
2563 '+' => Ok(self.builder.build_float_add(lhs, rhs, name)?),
2564 unknown => Err(anyhow!("unknown binop op '{}'", unknown)),
2565 }
2566 }
2567 AstKind::Monop(monop) => {
2568 let child =
2569 self.jit_compile_expr(name, monop.child.as_ref(), index, elmt, expr_index)?;
2570 match monop.op {
2571 '-' => Ok(self.builder.build_float_neg(child, name)?),
2572 unknown => Err(anyhow!("unknown monop op '{}'", unknown)),
2573 }
2574 }
2575 AstKind::Call(call) => match self.get_function(call.fn_name) {
2576 Some(function) => {
2577 let mut args: Vec<BasicMetadataValueEnum> = Vec::new();
2578 for arg in call.args.iter() {
2579 let arg_val =
2580 self.jit_compile_expr(name, arg.as_ref(), index, elmt, expr_index)?;
2581 args.push(BasicMetadataValueEnum::FloatValue(arg_val));
2582 }
2583 let ret_value = self
2584 .builder
2585 .build_call(function, args.as_slice(), name)?
2586 .try_as_basic_value()
2587 .unwrap_basic()
2588 .into_float_value();
2589 Ok(ret_value)
2590 }
2591 None => Err(anyhow!("unknown function call '{}'", call.fn_name)),
2592 },
2593 AstKind::CallArg(arg) => {
2594 self.jit_compile_expr(name, &arg.expression, index, elmt, expr_index)
2595 }
2596 AstKind::Number(value) => Ok(self.real_type.const_float(*value)),
2597 AstKind::Name(iname) => {
2598 let ptr = self.get_param(iname.name);
2599 let layout = self.layout.get_layout(iname.name).unwrap();
2600 let iname_elmt_index = if layout.is_dense() {
2601 let mut no_transform = true;
2603 let mut iname_index = Vec::new();
2604 for (i, c) in iname.indices.iter().enumerate() {
2605 let pi = elmt
2608 .indices()
2609 .iter()
2610 .position(|x| x == c)
2611 .unwrap_or(elmt.indices().len());
2612 if let Some(indice) =
2614 iname.indice.as_ref().map(|i| i.kind.as_indice().unwrap())
2615 {
2616 let start = indice.first.as_ref().kind.as_integer().unwrap();
2617 let start_intval = self
2618 .context
2619 .i32_type()
2620 .const_int(start.try_into().unwrap(), false);
2621 let index_pi = if pi >= index.len() {
2623 self.context.i32_type().const_int(0, false)
2624 } else {
2625 index[pi]
2626 };
2627 let index_pi =
2628 self.builder.build_int_add(index_pi, start_intval, name)?;
2629 iname_index.push(index_pi);
2630 } else {
2631 iname_index.push(index[pi]);
2632 }
2633 no_transform = no_transform && pi == i;
2634 }
2635 let iname_index: Vec<IntValue> = iname_index
2637 .iter()
2638 .enumerate()
2639 .map(|(i, &idx)| {
2640 if layout.shape()[i] == 1 {
2641 self.context.i32_type().const_int(0, false)
2642 } else {
2643 idx
2644 }
2645 })
2646 .collect();
2647 if !iname_index.is_empty() {
2650 let mut iname_elmt_index = *iname_index.last().unwrap();
2652 let mut stride = 1u64;
2653 for i in (0..iname_index.len() - 1).rev() {
2654 let iname_i = iname_index[i];
2655 let shapei: u64 = layout.shape()[i + 1].try_into().unwrap();
2656 stride *= shapei;
2657 let stride_intval = self.context.i32_type().const_int(stride, false);
2658 let stride_mul_i =
2659 self.builder.build_int_mul(stride_intval, iname_i, name)?;
2660 iname_elmt_index =
2661 self.builder
2662 .build_int_add(iname_elmt_index, stride_mul_i, name)?;
2663 }
2664 iname_elmt_index
2665 } else {
2666 let zero = self.context.i32_type().const_int(0, false);
2668 zero
2669 }
2670 } else if layout.is_sparse() || layout.is_diagonal() {
2671 let expr_layout = elmt.expr_layout();
2672
2673 if expr_layout != layout {
2674 let permutation =
2681 DataLayout::permutation(elmt, iname.indices.as_slice(), layout);
2682 if let Some(base_binary_layout_index) =
2683 self.layout
2684 .get_binary_layout_index(layout, expr_layout, permutation)
2685 {
2686 let binary_layout_index = self.builder.build_int_add(
2687 self.int_type
2688 .const_int(base_binary_layout_index.try_into().unwrap(), false),
2689 expr_index,
2690 name,
2691 )?;
2692
2693 let indices_ptr = Self::get_ptr_to_index(
2694 &self.builder,
2695 self.int_type,
2696 self.get_param("indices"),
2697 binary_layout_index,
2698 name,
2699 );
2700 let mapped_index = self
2701 .build_load(self.int_type, indices_ptr, name)?
2702 .into_int_value();
2703 let is_less_than_zero = self.builder.build_int_compare(
2704 IntPredicate::SLT,
2705 mapped_index,
2706 self.int_type.const_int(0, true),
2707 "sparse_index_check",
2708 )?;
2709 let is_less_than_zero_block =
2710 self.context.append_basic_block(self.fn_value(), "lt_zero");
2711 let not_less_than_zero_block = self
2712 .context
2713 .append_basic_block(self.fn_value(), "not_lt_zero");
2714 let merge_block =
2715 self.context.append_basic_block(self.fn_value(), "merge");
2716 self.builder.build_conditional_branch(
2717 is_less_than_zero,
2718 is_less_than_zero_block,
2719 not_less_than_zero_block,
2720 )?;
2721
2722 self.builder.position_at_end(is_less_than_zero_block);
2724 let zero_value = self.real_type.const_float(0.);
2725 self.builder.build_unconditional_branch(merge_block)?;
2726
2727 self.builder.position_at_end(not_less_than_zero_block);
2729 let value_ptr = Self::get_ptr_to_index(
2730 &self.builder,
2731 self.real_type,
2732 ptr,
2733 mapped_index,
2734 name,
2735 );
2736 let value = self
2737 .build_load(self.real_type, value_ptr, name)?
2738 .into_float_value();
2739 self.builder.build_unconditional_branch(merge_block)?;
2740
2741 self.builder.position_at_end(merge_block);
2743 let if_return_value =
2744 self.builder.build_phi(self.real_type, "sparse_value")?;
2745 if_return_value.add_incoming(&[(&zero_value, is_less_than_zero_block)]);
2746 if_return_value.add_incoming(&[(&value, not_less_than_zero_block)]);
2747
2748 let phi_value = if_return_value.as_basic_value().into_float_value();
2749 return Ok(phi_value);
2750 } else {
2751 expr_index
2752 }
2753 } else {
2754 expr_index
2756 }
2757 } else {
2758 panic!("unexpected layout");
2759 };
2760 let value_ptr = Self::get_ptr_to_index(
2761 &self.builder,
2762 self.real_type,
2763 ptr,
2764 iname_elmt_index,
2765 name,
2766 );
2767 Ok(self
2768 .build_load(self.real_type, value_ptr, name)?
2769 .into_float_value())
2770 }
2771 AstKind::NamedGradient(name) => {
2772 let name_str = name.to_string();
2773 let ptr = self.get_param(name_str.as_str());
2774 Ok(self
2775 .build_load(self.real_type, *ptr, name_str.as_str())?
2776 .into_float_value())
2777 }
2778 AstKind::Index(_) => todo!(),
2779 AstKind::Slice(_) => todo!(),
2780 AstKind::Integer(_) => todo!(),
2781 _ => panic!("unexprected astkind"),
2782 }
2783 }
2784
2785 fn clear(&mut self) {
2786 self.variables.clear();
2787 self.fn_value_opt = None;
2789 self.tensor_ptr_opt = None;
2790 }
2791
2792 fn function_arg_alloca(&mut self, name: &str, arg: BasicValueEnum<'ctx>) -> PointerValue<'ctx> {
2793 match arg {
2794 BasicValueEnum::PointerValue(v) => v,
2795 BasicValueEnum::FloatValue(v) => {
2796 let alloca = self
2797 .create_entry_block_builder()
2798 .build_alloca(arg.get_type(), name)
2799 .unwrap();
2800 self.builder.build_store(alloca, v).unwrap();
2801 alloca
2802 }
2803 BasicValueEnum::IntValue(v) => {
2804 let alloca = self
2805 .create_entry_block_builder()
2806 .build_alloca(arg.get_type(), name)
2807 .unwrap();
2808 self.builder.build_store(alloca, v).unwrap();
2809 alloca
2810 }
2811 _ => unreachable!(),
2812 }
2813 }
2814
    /// Compile the `set_u0` function, which fills the initial state vector
    /// `u0` after first evaluating the model's input-dependent definitions.
    ///
    /// Signature of the generated function:
    /// `set_u0(u0: *real, data: *real, thread_id: int, thread_dim: int)`.
    ///
    /// Returns the verified LLVM function, or an error (after deleting the
    /// function) if LLVM verification fails.
    pub fn compile_set_u0<'m>(
        &mut self,
        model: &'m DiscreteModel,
        code: Option<&str>,
    ) -> Result<FunctionValue<'ctx>> {
        self.clear();
        let fn_arg_names = &["u0", "data", "thread_id", "thread_dim"];
        let function = self.add_function(
            "set_u0",
            fn_arg_names,
            &[
                self.real_ptr_type.into(),
                self.real_ptr_type.into(),
                self.int_type.into(),
                self.int_type.into(),
            ],
            None,
            false,
        );

        // Mark the pointer params (u0, data) noalias for better optimization.
        let alias_id = Attribute::get_named_enum_kind_id("noalias");
        let noalign = self.context.create_enum_attribute(alias_id, 0);
        for i in &[0, 1] {
            function.add_attribute(AttributeLoc::Param(*i), noalign);
        }

        let _basic_block = self.start_function(function, code);

        // Spill scalar args to allocas and register every arg by name.
        for (i, arg) in function.get_param_iter().enumerate() {
            let name = fn_arg_names[i];
            let alloca = self.function_arg_alloca(name, arg);
            self.insert_param(name, alloca);
        }

        self.insert_data(model);
        self.insert_indices();

        // One barrier per compiled tensor plus one for the state itself, so
        // threaded executions stay in lockstep between tensor evaluations.
        let mut nbarriers = 0;
        let total_barriers = (model.input_dep_defns().len() + 1) as u64;
        #[allow(clippy::explicit_counter_loop)]
        for a in model.input_dep_defns() {
            self.jit_compile_tensor(a, Some(*self.get_var(a)), code)?;
            self.jit_compile_call_barrier(nbarriers, total_barriers);
            nbarriers += 1;
        }

        self.jit_compile_tensor(model.state(), Some(*self.get_param("u0")), code)?;
        self.jit_compile_call_barrier(nbarriers, total_barriers);

        self.builder.build_return(None)?;

        if function.verify(true) {
            Ok(function)
        } else {
            // Dump the bad IR for diagnosis, then remove the broken function
            // so the module stays valid.
            function.print_to_stderr();
            unsafe {
                function.delete();
            }
            Err(anyhow!("Invalid generated function."))
        }
    }
2877
2878 pub fn compile_calc_out<'m>(
2879 &mut self,
2880 model: &'m DiscreteModel,
2881 include_constants: bool,
2882 code: Option<&str>,
2883 ) -> Result<FunctionValue<'ctx>> {
2884 self.clear();
2885 let fn_arg_names = &["t", "u", "data", "out", "thread_id", "thread_dim"];
2886 let function_name = if include_constants {
2887 "calc_out_full"
2888 } else {
2889 "calc_out"
2890 };
2891 let function = self.add_function(
2892 function_name,
2893 fn_arg_names,
2894 &[
2895 self.real_type.into(),
2896 self.real_ptr_type.into(),
2897 self.real_ptr_type.into(),
2898 self.real_ptr_type.into(),
2899 self.int_type.into(),
2900 self.int_type.into(),
2901 ],
2902 None,
2903 false,
2904 );
2905
2906 let alias_id = Attribute::get_named_enum_kind_id("noalias");
2908 let noalign = self.context.create_enum_attribute(alias_id, 0);
2909 for i in &[1, 2] {
2910 function.add_attribute(AttributeLoc::Param(*i), noalign);
2911 }
2912
2913 let _basic_block = self.start_function(function, code);
2914
2915 for (i, arg) in function.get_param_iter().enumerate() {
2916 let name = fn_arg_names[i];
2917 let alloca = self.function_arg_alloca(name, arg);
2918 self.insert_param(name, alloca);
2919 }
2920
2921 self.insert_state(model.state());
2922 self.insert_data(model);
2923 self.insert_indices();
2924
2925 if let Some(out) = model.out() {
2931 let mut nbarriers = 0;
2932 let mut total_barriers =
2933 (model.time_dep_defns().len() + model.state_dep_defns().len() + 1) as u64;
2934 if include_constants {
2935 total_barriers += model.input_dep_defns().len() as u64;
2936 for tensor in model.input_dep_defns() {
2938 self.jit_compile_tensor(tensor, Some(*self.get_var(tensor)), code)?;
2939 self.jit_compile_call_barrier(nbarriers, total_barriers);
2940 nbarriers += 1;
2941 }
2942 }
2943
2944 for tensor in model.time_dep_defns() {
2946 self.jit_compile_tensor(tensor, Some(*self.get_var(tensor)), code)?;
2947 self.jit_compile_call_barrier(nbarriers, total_barriers);
2948 nbarriers += 1;
2949 }
2950
2951 #[allow(clippy::explicit_counter_loop)]
2953 for a in model.state_dep_defns() {
2954 self.jit_compile_tensor(a, Some(*self.get_var(a)), code)?;
2955 self.jit_compile_call_barrier(nbarriers, total_barriers);
2956 nbarriers += 1;
2957 }
2958
2959 self.jit_compile_tensor(out, Some(*self.get_var(model.out().unwrap())), code)?;
2960 self.jit_compile_call_barrier(nbarriers, total_barriers);
2961 }
2962 self.builder.build_return(None)?;
2963
2964 if function.verify(true) {
2965 Ok(function)
2966 } else {
2967 function.print_to_stderr();
2968 unsafe {
2969 function.delete();
2970 }
2971 Err(anyhow!("Invalid generated function."))
2972 }
2973 }
2974
    /// Compile the `calc_stop` function, which evaluates the model's stop
    /// (root-finding) tensor into `root`. Models without a stop tensor get an
    /// empty (return-only) body.
    ///
    /// Returns the verified LLVM function, or an error (after deleting the
    /// function) if LLVM verification fails.
    pub fn compile_calc_stop<'m>(
        &mut self,
        model: &'m DiscreteModel,
        code: Option<&str>,
    ) -> Result<FunctionValue<'ctx>> {
        self.clear();
        let fn_arg_names = &["t", "u", "data", "root", "thread_id", "thread_dim"];
        let function = self.add_function(
            "calc_stop",
            fn_arg_names,
            &[
                self.real_type.into(),
                self.real_ptr_type.into(),
                self.real_ptr_type.into(),
                self.real_ptr_type.into(),
                self.int_type.into(),
                self.int_type.into(),
            ],
            None,
            false,
        );

        // Mark the pointer params (u, data, root) noalias for optimization.
        let alias_id = Attribute::get_named_enum_kind_id("noalias");
        let noalign = self.context.create_enum_attribute(alias_id, 0);
        for i in &[1, 2, 3] {
            function.add_attribute(AttributeLoc::Param(*i), noalign);
        }

        let _basic_block = self.start_function(function, code);

        // Spill scalar args to allocas and register every arg by name.
        for (i, arg) in function.get_param_iter().enumerate() {
            let name = fn_arg_names[i];
            let alloca = self.function_arg_alloca(name, arg);
            self.insert_param(name, alloca);
        }

        self.insert_state(model.state());
        self.insert_data(model);
        self.insert_indices();

        if let Some(stop) = model.stop() {
            // One barrier per compiled tensor plus one for the stop tensor.
            let mut nbarriers = 0;
            let total_barriers =
                (model.time_dep_defns().len() + model.state_dep_defns().len() + 1) as u64;
            for tensor in model.time_dep_defns() {
                self.jit_compile_tensor(tensor, Some(*self.get_var(tensor)), code)?;
                self.jit_compile_call_barrier(nbarriers, total_barriers);
                nbarriers += 1;
            }

            for a in model.state_dep_defns() {
                self.jit_compile_tensor(a, Some(*self.get_var(a)), code)?;
                self.jit_compile_call_barrier(nbarriers, total_barriers);
                nbarriers += 1;
            }

            let res_ptr = self.get_param("root");
            self.jit_compile_tensor(stop, Some(*res_ptr), code)?;
            self.jit_compile_call_barrier(nbarriers, total_barriers);
        }
        self.builder.build_return(None)?;

        if function.verify(true) {
            Ok(function)
        } else {
            // Dump the bad IR for diagnosis, then remove the broken function
            // so the module stays valid.
            function.print_to_stderr();
            unsafe {
                function.delete();
            }
            Err(anyhow!("Invalid generated function."))
        }
    }
3050
    /// Compile the `rhs` (or, with `include_constants`, `rhs_full`) function,
    /// which evaluates the model's right-hand side into `rr`.
    ///
    /// `rhs_full` additionally recomputes the input-dependent definitions
    /// before the time- and state-dependent ones.
    ///
    /// Returns the verified LLVM function, or an error (after deleting the
    /// function) if LLVM verification fails.
    pub fn compile_rhs<'m>(
        &mut self,
        model: &'m DiscreteModel,
        include_constants: bool,
        code: Option<&str>,
    ) -> Result<FunctionValue<'ctx>> {
        self.clear();
        let fn_arg_names = &["t", "u", "data", "rr", "thread_id", "thread_dim"];
        let function_name = if include_constants { "rhs_full" } else { "rhs" };
        let function = self.add_function(
            function_name,
            fn_arg_names,
            &[
                self.real_type.into(),
                self.real_ptr_type.into(),
                self.real_ptr_type.into(),
                self.real_ptr_type.into(),
                self.int_type.into(),
                self.int_type.into(),
            ],
            None,
            false,
        );

        // Mark the pointer params (u, data, rr) noalias for optimization.
        let alias_id = Attribute::get_named_enum_kind_id("noalias");
        let noalign = self.context.create_enum_attribute(alias_id, 0);
        for i in &[1, 2, 3] {
            function.add_attribute(AttributeLoc::Param(*i), noalign);
        }

        let _basic_block = self.start_function(function, code);

        // Spill scalar args to allocas and register every arg by name.
        for (i, arg) in function.get_param_iter().enumerate() {
            let name = fn_arg_names[i];
            let alloca = self.function_arg_alloca(name, arg);
            self.insert_param(name, alloca);
        }

        self.insert_state(model.state());
        self.insert_data(model);
        self.insert_indices();

        // One barrier per compiled tensor plus one for the rhs tensor itself.
        let mut nbarriers = 0;
        let mut total_barriers =
            (model.time_dep_defns().len() + model.state_dep_defns().len() + 1) as u64;
        if include_constants {
            total_barriers += model.input_dep_defns().len() as u64;
            for tensor in model.input_dep_defns() {
                self.jit_compile_tensor(tensor, Some(*self.get_var(tensor)), code)?;
                self.jit_compile_call_barrier(nbarriers, total_barriers);
                nbarriers += 1;
            }
        }

        for tensor in model.time_dep_defns() {
            self.jit_compile_tensor(tensor, Some(*self.get_var(tensor)), code)?;
            self.jit_compile_call_barrier(nbarriers, total_barriers);
            nbarriers += 1;
        }

        for a in model.state_dep_defns() {
            self.jit_compile_tensor(a, Some(*self.get_var(a)), code)?;
            self.jit_compile_call_barrier(nbarriers, total_barriers);
            nbarriers += 1;
        }

        let res_ptr = self.get_param("rr");
        self.jit_compile_tensor(model.rhs(), Some(*res_ptr), code)?;
        self.jit_compile_call_barrier(nbarriers, total_barriers);

        self.builder.build_return(None)?;

        if function.verify(true) {
            Ok(function)
        } else {
            // Dump the bad IR for diagnosis, then remove the broken function
            // so the module stays valid.
            function.print_to_stderr();
            unsafe {
                function.delete();
            }
            Err(anyhow!("Invalid generated function."))
        }
    }
3138
3139 pub fn compile_mass<'m>(
3140 &mut self,
3141 model: &'m DiscreteModel,
3142 code: Option<&str>,
3143 ) -> Result<FunctionValue<'ctx>> {
3144 self.clear();
3145 let fn_arg_names = &["t", "dudt", "data", "rr", "thread_id", "thread_dim"];
3146 let function = self.add_function(
3147 "mass",
3148 fn_arg_names,
3149 &[
3150 self.real_type.into(),
3151 self.real_ptr_type.into(),
3152 self.real_ptr_type.into(),
3153 self.real_ptr_type.into(),
3154 self.int_type.into(),
3155 self.int_type.into(),
3156 ],
3157 None,
3158 false,
3159 );
3160
3161 let alias_id = Attribute::get_named_enum_kind_id("noalias");
3163 let noalign = self.context.create_enum_attribute(alias_id, 0);
3164 for i in &[1, 2, 3] {
3165 function.add_attribute(AttributeLoc::Param(*i), noalign);
3166 }
3167
3168 let _basic_block = self.start_function(function, code);
3169
3170 for (i, arg) in function.get_param_iter().enumerate() {
3171 let name = fn_arg_names[i];
3172 let alloca = self.function_arg_alloca(name, arg);
3173 self.insert_param(name, alloca);
3174 }
3175
3176 if model.state_dot().is_some() && model.lhs().is_some() {
3178 let state_dot = model.state_dot().unwrap();
3179 let lhs = model.lhs().unwrap();
3180
3181 self.insert_dot_state(state_dot);
3182 self.insert_data(model);
3183 self.insert_indices();
3184
3185 let mut nbarriers = 0;
3187 let total_barriers =
3188 (model.time_dep_defns().len() + model.dstate_dep_defns().len() + 1) as u64;
3189 for tensor in model.time_dep_defns() {
3190 self.jit_compile_tensor(tensor, Some(*self.get_var(tensor)), code)?;
3191 self.jit_compile_call_barrier(nbarriers, total_barriers);
3192 nbarriers += 1;
3193 }
3194
3195 for a in model.dstate_dep_defns() {
3196 self.jit_compile_tensor(a, Some(*self.get_var(a)), code)?;
3197 self.jit_compile_call_barrier(nbarriers, total_barriers);
3198 nbarriers += 1;
3199 }
3200
3201 let res_ptr = self.get_param("rr");
3203 self.jit_compile_tensor(lhs, Some(*res_ptr), code)?;
3204 self.jit_compile_call_barrier(nbarriers, total_barriers);
3205 }
3206
3207 self.builder.build_return(None)?;
3208
3209 if function.verify(true) {
3210 Ok(function)
3211 } else {
3212 function.print_to_stderr();
3213 unsafe {
3214 function.delete();
3215 }
3216 Err(anyhow!("Invalid generated function."))
3217 }
3218 }
3219
    /// Use Enzyme to generate a wrapper function `fn_name` that computes the
    /// derivative of `original_function`, in forward or reverse mode per `mode`.
    ///
    /// For each original argument, `args_type[i]` selects its activity:
    /// `Dup`/`DupNoNeed` arguments get an extra shadow parameter appended
    /// directly after the primal parameter in the wrapper's signature, while
    /// `Const` arguments are passed through unchanged.
    ///
    /// Returns the verified wrapper function, or an error if LLVM verification
    /// of the wrapper fails.
    pub fn compile_gradient(
        &mut self,
        original_function: FunctionValue<'ctx>,
        args_type: &[CompileGradientArgType],
        mode: CompileMode,
        fn_name: &str,
    ) -> Result<FunctionValue<'ctx>> {
        self.clear();

        // Parameter types of the wrapper function being built.
        let mut fn_type: Vec<BasicMetadataTypeEnum> = Vec::new();

        let orig_fn_type_ptr = Self::fn_pointer_type(self.context, original_function.get_type());

        // NOTE(review): enzyme_fn_type is populated but never read afterwards in
        // this function — presumably retained for reference; confirm it can go.
        let mut enzyme_fn_type: Vec<BasicMetadataTypeEnum> = vec![orig_fn_type_ptr.into()];
        // start_param_index[i]: index of original argument i's primal parameter
        // in the wrapper signature (its shadow, if any, sits at +1).
        let mut start_param_index: Vec<u32> = Vec::new();
        // Wrapper parameter indices that are pointers (marked `noalias` below).
        let mut ptr_arg_indices: Vec<u32> = Vec::new();
        for (i, arg) in original_function.get_param_iter().enumerate() {
            start_param_index.push(u32::try_from(fn_type.len()).unwrap());
            let arg_type = arg.get_type();
            fn_type.push(arg_type.into());

            enzyme_fn_type.push(self.int_type.into());
            enzyme_fn_type.push(arg.get_type().into());

            if arg_type.is_pointer_type() {
                ptr_arg_indices.push(u32::try_from(i).unwrap());
            }

            match args_type[i] {
                // Active (duplicated) arguments get a shadow parameter of the
                // same type immediately after the primal.
                CompileGradientArgType::Dup | CompileGradientArgType::DupNoNeed => {
                    fn_type.push(arg.get_type().into());
                    enzyme_fn_type.push(arg.get_type().into());
                }
                CompileGradientArgType::Const => {}
            }
        }
        let fn_arg_names = fn_type
            .iter()
            .enumerate()
            .map(|(i, _)| format!("arg{}", i))
            .collect::<Vec<String>>();
        let fn_arg_names_ref = fn_arg_names
            .iter()
            .map(|s| s.as_str())
            .collect::<Vec<&str>>();
        let function = self.add_function(fn_name, &fn_arg_names_ref, &fn_type, None, false);

        // Mark all pointer arguments `noalias` so Enzyme/LLVM may assume they
        // do not overlap.
        let alias_id = Attribute::get_named_enum_kind_id("noalias");
        let noalign = self.context.create_enum_attribute(alias_id, 0);
        for i in ptr_arg_indices {
            function.add_attribute(AttributeLoc::Param(i), noalign);
        }

        let _basic_block = self.start_function(function, None);

        // Arguments forwarded to the Enzyme-generated function, plus the
        // per-argument activity annotations and type trees handed to Enzyme's
        // type analysis.
        let mut enzyme_fn_args: Vec<BasicMetadataValueEnum> = Vec::new();
        let mut input_activity = Vec::new();
        let mut arg_trees = Vec::new();
        for (i, arg) in original_function.get_param_iter().enumerate() {
            let param_index = start_param_index[i];
            let fn_arg = function.get_nth_param(param_index).unwrap();

            // Map the LLVM argument type to Enzyme's concrete type descriptor;
            // float width follows the configured real type.
            let concrete_type = match arg.get_type() {
                BasicTypeEnum::PointerType(_t) => CConcreteType_DT_Pointer,
                BasicTypeEnum::FloatType(_t) => match self.diffsl_real_type {
                    RealType::F32 => CConcreteType_DT_Float,
                    RealType::F64 => CConcreteType_DT_Double,
                },
                BasicTypeEnum::IntType(_) => CConcreteType_DT_Integer,
                _ => panic!("unsupported type"),
            };
            let new_tree = unsafe {
                EnzymeNewTypeTreeCT(
                    concrete_type,
                    self.context.as_ctx_ref() as *mut LLVMOpaqueContext,
                )
            };
            // NOTE(review): the -1 offset presumably applies the type at all
            // offsets of the tree — confirm against the Enzyme C API.
            unsafe { EnzymeTypeTreeOnlyEq(new_tree, -1) };

            // Pointer arguments are described as pointer-to-real by merging an
            // inner float/double tree into the pointer tree.
            if concrete_type == CConcreteType_DT_Pointer {
                let inner_concrete_type = match self.diffsl_real_type {
                    RealType::F32 => CConcreteType_DT_Float,
                    RealType::F64 => CConcreteType_DT_Double,
                };
                let inner_new_tree = unsafe {
                    EnzymeNewTypeTreeCT(
                        inner_concrete_type,
                        self.context.as_ctx_ref() as *mut LLVMOpaqueContext,
                    )
                };
                unsafe { EnzymeTypeTreeOnlyEq(inner_new_tree, -1) };
                unsafe { EnzymeTypeTreeOnlyEq(inner_new_tree, -1) };
                unsafe { EnzymeMergeTypeTree(new_tree, inner_new_tree) };
            }

            arg_trees.push(new_tree);
            // Forward the primal argument (and, for Dup*, its shadow at
            // param_index + 1) and record the matching Enzyme activity.
            match args_type[i] {
                CompileGradientArgType::Dup => {
                    enzyme_fn_args.push(fn_arg.into());

                    let fn_darg = function.get_nth_param(param_index + 1).unwrap();
                    enzyme_fn_args.push(fn_darg.into());

                    input_activity.push(CDIFFE_TYPE_DFT_DUP_ARG);
                }
                CompileGradientArgType::DupNoNeed => {
                    enzyme_fn_args.push(fn_arg.into());

                    let fn_darg = function.get_nth_param(param_index + 1).unwrap();
                    enzyme_fn_args.push(fn_darg.into());

                    input_activity.push(CDIFFE_TYPE_DFT_DUP_NONEED);
                }
                CompileGradientArgType::Const => {
                    enzyme_fn_args.push(fn_arg.into());

                    input_activity.push(CDIFFE_TYPE_DFT_CONSTANT);
                }
            }
        }
        // The return value is treated as inactive: no primal return, no
        // differential return, constant activity, "Anything" type tree.
        let ret_primary_ret = false;
        let diff_ret = false;
        let ret_activity = CDIFFE_TYPE_DFT_CONSTANT;
        let ret_tree = unsafe {
            EnzymeNewTypeTreeCT(
                CConcreteType_DT_Anything,
                self.context.as_ctx_ref() as *mut LLVMOpaqueContext,
            )
        };

        // Let Enzyme run its base optimisations on the generated function.
        let fnc_opt_base = true;
        let logic_ref: EnzymeLogicRef = unsafe { CreateEnzymeLogic(fnc_opt_base as u8) };

        // One empty known-values entry per active/const input.
        let kv_tmp = IntList {
            data: std::ptr::null_mut(),
            size: 0,
        };
        let mut known_values = vec![kv_tmp; input_activity.len()];

        let fn_type_info = CFnTypeInfo {
            Arguments: arg_trees.as_mut_ptr(),
            Return: ret_tree,
            KnownValues: known_values.as_mut_ptr(),
        };

        let type_analysis: EnzymeTypeAnalysisRef =
            unsafe { CreateTypeAnalysis(logic_ref, std::ptr::null_mut(), std::ptr::null_mut(), 0) };

        // NOTE(review): 0 presumably marks each argument as cacheable — confirm
        // against the Enzyme API.
        let mut args_uncacheable = vec![0; arg_trees.len()];

        let enzyme_function = match mode {
            CompileMode::Forward | CompileMode::ForwardSens => unsafe {
                EnzymeCreateForwardDiff(
                    logic_ref,
                    std::ptr::null_mut(),
                    std::ptr::null_mut(),
                    original_function.as_value_ref(),
                    ret_activity,
                    input_activity.as_mut_ptr(),
                    input_activity.len(),
                    type_analysis,
                    ret_primary_ret as u8,
                    CDerivativeMode_DEM_ForwardMode,
                    1,
                    0,
                    0,
                    1,
                    std::ptr::null_mut(),
                    fn_type_info,
                    1,
                    args_uncacheable.as_mut_ptr(),
                    args_uncacheable.len(),
                    std::ptr::null_mut(),
                )
            },
            CompileMode::Reverse | CompileMode::ReverseSens => {
                let mut call_enzyme = || unsafe {
                    EnzymeCreatePrimalAndGradient(
                        logic_ref,
                        std::ptr::null_mut(),
                        std::ptr::null_mut(),
                        original_function.as_value_ref(),
                        ret_activity,
                        input_activity.as_mut_ptr(),
                        input_activity.len(),
                        type_analysis,
                        ret_primary_ret as u8,
                        diff_ret as u8,
                        CDerivativeMode_DEM_ReverseModeCombined,
                        0,
                        0,
                        1,
                        1,
                        std::ptr::null_mut(),
                        0,
                        fn_type_info,
                        0,
                        args_uncacheable.as_mut_ptr(),
                        args_uncacheable.len(),
                        std::ptr::null_mut(),
                        if self.threaded { 1 } else { 0 },
                    )
                };
                if self.threaded {
                    // Enzyme's call-handler registry is global state, so
                    // serialise register/differentiate/unregister.
                    // NOTE(review): presumably `my_mutex` is a process-wide lock
                    // shared by all compilations — confirm.
                    let _lock = my_mutex.lock().unwrap();
                    // Teach Enzyme how to differentiate through the "barrier"
                    // calls used for thread synchronisation.
                    let barrier_string = CString::new("barrier").unwrap();
                    unsafe {
                        EnzymeRegisterCallHandler(
                            barrier_string.as_ptr(),
                            Some(fwd_handler),
                            Some(rev_handler),
                        )
                    };
                    let ret = call_enzyme();
                    // Unregister so later compilations are unaffected.
                    unsafe { EnzymeRegisterCallHandler(barrier_string.as_ptr(), None, None) };
                    ret
                } else {
                    call_enzyme()
                }
            }
        };

        // Free Enzyme bookkeeping; the generated LLVM function lives on in the
        // module and is unaffected.
        unsafe { FreeEnzymeLogic(logic_ref) };
        unsafe { FreeTypeAnalysis(type_analysis) };
        unsafe { EnzymeFreeTypeTree(ret_tree) };
        for tree in arg_trees {
            unsafe { EnzymeFreeTypeTree(tree) };
        }

        // Wrap the raw LLVMValueRef from Enzyme and call it from the wrapper
        // with the forwarded primal/shadow arguments.
        let enzyme_function =
            unsafe { FunctionValue::new(enzyme_function as LLVMValueRef) }.unwrap();
        self.builder
            .build_call(enzyme_function, enzyme_fn_args.as_slice(), "enzyme_call")?;

        self.builder.build_return(None)?;

        if function.verify(true) {
            Ok(function)
        } else {
            function.print_to_stderr();
            enzyme_function.print_to_stderr();
            unsafe {
                function.delete();
            }
            Err(anyhow!("Invalid generated function."))
        }
    }
3485
3486 pub fn compile_get_dims(&mut self, model: &DiscreteModel) -> Result<FunctionValue<'ctx>> {
3487 self.clear();
3488 let fn_arg_names = &["states", "inputs", "outputs", "data", "stop", "has_mass"];
3489 let function = self.add_function(
3490 "get_dims",
3491 fn_arg_names,
3492 &[
3493 self.int_ptr_type.into(),
3494 self.int_ptr_type.into(),
3495 self.int_ptr_type.into(),
3496 self.int_ptr_type.into(),
3497 self.int_ptr_type.into(),
3498 self.int_ptr_type.into(),
3499 ],
3500 None,
3501 false,
3502 );
3503 let _block = self.start_function(function, None);
3504
3505 for (i, arg) in function.get_param_iter().enumerate() {
3506 let name = fn_arg_names[i];
3507 let alloca = self.function_arg_alloca(name, arg);
3508 self.insert_param(name, alloca);
3509 }
3510
3511 self.insert_indices();
3512
3513 let number_of_states = model.state().nnz() as u64;
3514 let number_of_inputs = model.input().map(|inp| inp.nnz()).unwrap_or(0) as u64;
3515 let number_of_outputs = match model.out() {
3516 Some(out) => out.nnz() as u64,
3517 None => 0,
3518 };
3519 let number_of_stop = if let Some(stop) = model.stop() {
3520 stop.nnz() as u64
3521 } else {
3522 0
3523 };
3524 let has_mass = match model.lhs().is_some() {
3525 true => 1u64,
3526 false => 0u64,
3527 };
3528 let data_len = self.layout.data().len() as u64;
3529 self.builder.build_store(
3530 *self.get_param("states"),
3531 self.int_type.const_int(number_of_states, false),
3532 )?;
3533 self.builder.build_store(
3534 *self.get_param("inputs"),
3535 self.int_type.const_int(number_of_inputs, false),
3536 )?;
3537 self.builder.build_store(
3538 *self.get_param("outputs"),
3539 self.int_type.const_int(number_of_outputs, false),
3540 )?;
3541 self.builder.build_store(
3542 *self.get_param("data"),
3543 self.int_type.const_int(data_len, false),
3544 )?;
3545 self.builder.build_store(
3546 *self.get_param("stop"),
3547 self.int_type.const_int(number_of_stop, false),
3548 )?;
3549 self.builder.build_store(
3550 *self.get_param("has_mass"),
3551 self.int_type.const_int(has_mass, false),
3552 )?;
3553 self.builder.build_return(None)?;
3554
3555 if function.verify(true) {
3556 Ok(function)
3557 } else {
3558 function.print_to_stderr();
3559 unsafe {
3560 function.delete();
3561 }
3562 Err(anyhow!("Invalid generated function."))
3563 }
3564 }
3565
3566 pub fn compile_get_tensor(
3567 &mut self,
3568 model: &DiscreteModel,
3569 name: &str,
3570 ) -> Result<FunctionValue<'ctx>> {
3571 self.clear();
3572 let real_ptr_ptr_type = Self::pointer_type(self.context, self.real_ptr_type.into());
3573 let function_name = format!("get_tensor_{name}");
3574 let fn_arg_names = &["data", "tensor_data", "tensor_size"];
3575 let function = self.add_function(
3576 function_name.as_str(),
3577 fn_arg_names,
3578 &[
3579 self.real_ptr_type.into(),
3580 real_ptr_ptr_type.into(),
3581 self.int_ptr_type.into(),
3582 ],
3583 None,
3584 false,
3585 );
3586 let _basic_block = self.start_function(function, None);
3587
3588 for (i, arg) in function.get_param_iter().enumerate() {
3589 let name = fn_arg_names[i];
3590 let alloca = self.function_arg_alloca(name, arg);
3591 self.insert_param(name, alloca);
3592 }
3593
3594 self.insert_data(model);
3595 let ptr = self.get_param(name);
3596 let tensor_size = self.layout.get_layout(name).unwrap().nnz() as u64;
3597 let tensor_size_value = self.int_type.const_int(tensor_size, false);
3598 self.builder
3599 .build_store(*self.get_param("tensor_data"), ptr.as_basic_value_enum())?;
3600 self.builder
3601 .build_store(*self.get_param("tensor_size"), tensor_size_value)?;
3602 self.builder.build_return(None)?;
3603
3604 if function.verify(true) {
3605 Ok(function)
3606 } else {
3607 function.print_to_stderr();
3608 unsafe {
3609 function.delete();
3610 }
3611 Err(anyhow!("Invalid generated function."))
3612 }
3613 }
3614
3615 pub fn compile_get_constant(
3616 &mut self,
3617 model: &DiscreteModel,
3618 name: &str,
3619 ) -> Result<FunctionValue<'ctx>> {
3620 self.clear();
3621 let real_ptr_ptr_type = Self::pointer_type(self.context, self.real_ptr_type.into());
3622 let function_name = format!("get_constant_{name}");
3623 let fn_arg_names = &["tensor_data", "tensor_size"];
3624 let function = self.add_function(
3625 function_name.as_str(),
3626 fn_arg_names,
3627 &[real_ptr_ptr_type.into(), self.int_ptr_type.into()],
3628 None,
3629 false,
3630 );
3631 let _basic_block = self.start_function(function, None);
3632
3633 for (i, arg) in function.get_param_iter().enumerate() {
3634 let name = fn_arg_names[i];
3635 let alloca = self.function_arg_alloca(name, arg);
3636 self.insert_param(name, alloca);
3637 }
3638
3639 self.insert_constants(model);
3640 let ptr = self.get_param(name);
3641 let tensor_size = self.layout.get_layout(name).unwrap().nnz() as u64;
3642 let tensor_size_value = self.int_type.const_int(tensor_size, false);
3643 self.builder
3644 .build_store(*self.get_param("tensor_data"), ptr.as_basic_value_enum())?;
3645 self.builder
3646 .build_store(*self.get_param("tensor_size"), tensor_size_value)?;
3647 self.builder.build_return(None)?;
3648
3649 if function.verify(true) {
3650 Ok(function)
3651 } else {
3652 function.print_to_stderr();
3653 unsafe {
3654 function.delete();
3655 }
3656 Err(anyhow!("Invalid generated function."))
3657 }
3658 }
3659
    /// Build either `get_inputs` (copy the model's input tensor out to the
    /// external `inputs` buffer) or `set_inputs` (copy the `inputs` buffer into
    /// the model's input tensor), depending on `is_get`. Both variants take
    /// `(inputs, data)` real-pointer arguments.
    pub fn compile_inputs(
        &mut self,
        model: &DiscreteModel,
        is_get: bool,
    ) -> Result<FunctionValue<'ctx>> {
        self.clear();
        let function_name = if is_get { "get_inputs" } else { "set_inputs" };
        let fn_arg_names = &["inputs", "data"];
        let function = self.add_function(
            function_name,
            fn_arg_names,
            &[self.real_ptr_type.into(), self.real_ptr_type.into()],
            None,
            false,
        );
        let block = self.start_function(function, None);

        // Stack-allocate and register each incoming argument by name.
        for (i, arg) in function.get_param_iter().enumerate() {
            let name = fn_arg_names[i];
            let alloca = self.function_arg_alloca(name, arg);
            self.insert_param(name, alloca);
        }

        // If the model has no input tensor, the function is a no-op.
        if let Some(input) = model.input() {
            let name = input.name();
            self.insert_tensor(input, false);
            let ptr = self.get_var(input);
            // The input tensor occupies [0, nnz) of the external buffer.
            let inputs_start_index = self.int_type.const_int(0, false);
            let start_index = self.int_type.const_int(0, false);
            let end_index = self
                .int_type
                .const_int(input.nnz().try_into().unwrap(), false);

            // Element copy loop with a phi-based counter. This is do-while
            // shaped: the body runs once before the exit test.
            // NOTE(review): if input.nnz() == 0 the body still executes once
            // and touches element 0 — confirm nnz() > 0 is guaranteed here.
            let input_block = self.context.append_basic_block(function, name);
            self.builder.build_unconditional_branch(input_block)?;
            self.builder.position_at_end(input_block);
            let index = self.builder.build_phi(self.int_type, "i")?;
            index.add_incoming(&[(&start_index, block)]);

            // Pointers to the current element on both sides of the copy.
            let curr_input_index = index.as_basic_value().into_int_value();
            let input_ptr =
                Self::get_ptr_to_index(&self.builder, self.real_type, ptr, curr_input_index, name);
            let curr_inputs_index =
                self.builder
                    .build_int_add(inputs_start_index, curr_input_index, name)?;
            let inputs_ptr = Self::get_ptr_to_index(
                &self.builder,
                self.real_type,
                self.get_param("inputs"),
                curr_inputs_index,
                name,
            );
            // Copy one element per iteration; direction depends on is_get.
            if is_get {
                let input_value = self
                    .build_load(self.real_type, input_ptr, name)?
                    .into_float_value();
                self.builder.build_store(inputs_ptr, input_value)?;
            } else {
                let input_value = self
                    .build_load(self.real_type, inputs_ptr, name)?
                    .into_float_value();
                self.builder.build_store(input_ptr, input_value)?;
            }

            // Increment the counter and loop while next_index < end_index.
            let one = self.int_type.const_int(1, false);
            let next_index = self.builder.build_int_add(curr_input_index, one, name)?;
            index.add_incoming(&[(&next_index, input_block)]);

            let loop_while =
                self.builder
                    .build_int_compare(IntPredicate::ULT, next_index, end_index, name)?;
            let post_block = self.context.append_basic_block(function, name);
            self.builder
                .build_conditional_branch(loop_while, input_block, post_block)?;
            self.builder.position_at_end(post_block);
        }
        self.builder.build_return(None)?;

        if function.verify(true) {
            Ok(function)
        } else {
            function.print_to_stderr();
            // The generated IR is invalid; remove it from the module before erroring.
            unsafe {
                function.delete();
            }
            Err(anyhow!("Invalid generated function."))
        }
    }
3752
    /// Build a `set_id` function that fills the `id` buffer with one value per
    /// state entry: 1.0 for differential entries and 0.0 for algebraic ones
    /// (one copy loop per state block).
    pub fn compile_set_id(&mut self, model: &DiscreteModel) -> Result<FunctionValue<'ctx>> {
        self.clear();
        let fn_arg_names = &["id"];
        let function = self.add_function(
            "set_id",
            fn_arg_names,
            &[self.real_ptr_type.into()],
            None,
            false,
        );
        // `block` tracks the predecessor block of the next loop's phi; it is
        // re-assigned to each loop's post-block at the bottom of the iteration.
        let mut block = self.start_function(function, None);

        // Stack-allocate and register each incoming argument by name.
        for (i, arg) in function.get_param_iter().enumerate() {
            let name = fn_arg_names[i];
            let alloca = self.function_arg_alloca(name, arg);
            self.insert_param(name, alloca);
        }

        // Running offset of the current state block within the `id` buffer.
        let mut id_index = 0usize;
        for (blk, is_algebraic) in zip(model.state().elmts(), model.is_algebraic()) {
            let name = blk.name().unwrap_or("unknown");
            let id_start_index = self.int_type.const_int(id_index as u64, false);
            let blk_start_index = self.int_type.const_int(0, false);
            let blk_end_index = self
                .int_type
                .const_int(blk.nnz().try_into().unwrap(), false);

            // Per-block fill loop with a phi-based counter (do-while shaped:
            // the body runs once before the exit test).
            let blk_block = self.context.append_basic_block(function, name);
            self.builder.build_unconditional_branch(blk_block)?;
            self.builder.position_at_end(blk_block);
            let index = self.builder.build_phi(self.int_type, "i")?;
            index.add_incoming(&[(&blk_start_index, block)]);

            // id[id_start_index + i] = 0.0 if algebraic else 1.0.
            let curr_blk_index = index.as_basic_value().into_int_value();
            let curr_id_index = self
                .builder
                .build_int_add(id_start_index, curr_blk_index, name)?;
            let id_ptr = Self::get_ptr_to_index(
                &self.builder,
                self.real_type,
                self.get_param("id"),
                curr_id_index,
                name,
            );
            let is_algebraic_float = if *is_algebraic { 0.0 } else { 1.0 };
            let is_algebraic_value = self.real_type.const_float(is_algebraic_float);
            self.builder.build_store(id_ptr, is_algebraic_value)?;

            // Increment the counter and loop while next_index < blk_end_index.
            let one = self.int_type.const_int(1, false);
            let next_index = self.builder.build_int_add(curr_blk_index, one, name)?;
            index.add_incoming(&[(&next_index, blk_block)]);

            let loop_while = self.builder.build_int_compare(
                IntPredicate::ULT,
                next_index,
                blk_end_index,
                name,
            )?;
            let post_block = self.context.append_basic_block(function, name);
            self.builder
                .build_conditional_branch(loop_while, blk_block, post_block)?;
            self.builder.position_at_end(post_block);

            // The next block's phi is entered from this block's post-block.
            block = post_block;
            id_index += blk.nnz();
        }
        self.builder.build_return(None)?;

        if function.verify(true) {
            Ok(function)
        } else {
            function.print_to_stderr();
            // The generated IR is invalid; remove it from the module before erroring.
            unsafe {
                function.delete();
            }
            Err(anyhow!("Invalid generated function."))
        }
    }
3836
    /// Returns a reference to the LLVM module that holds all functions
    /// generated by this compiler instance.
    pub fn module(&self) -> &Module<'ctx> {
        &self.module
    }
3840}