1use aliasable::boxed::AliasableBox;
2use anyhow::{anyhow, Result};
3use inkwell::attributes::{Attribute, AttributeLoc};
4use inkwell::basic_block::BasicBlock;
5use inkwell::builder::Builder;
6use inkwell::context::{AsContextRef, Context};
7use inkwell::execution_engine::ExecutionEngine;
8use inkwell::intrinsics::Intrinsic;
9use inkwell::module::{Linkage, Module};
10use inkwell::passes::PassBuilderOptions;
11use inkwell::targets::{FileType, InitializationConfig, Target, TargetMachine, TargetTriple};
12use inkwell::types::{
13 BasicMetadataTypeEnum, BasicType, BasicTypeEnum, FloatType, FunctionType, IntType, PointerType,
14};
15use inkwell::values::{
16 AsValueRef, BasicMetadataValueEnum, BasicValue, BasicValueEnum, CallSiteValue, FloatValue,
17 FunctionValue, GlobalValue, IntValue, PointerValue,
18};
19use inkwell::{
20 AddressSpace, AtomicOrdering, AtomicRMWBinOp, FloatPredicate, GlobalVisibility, IntPredicate,
21 OptimizationLevel,
22};
23use llvm_sys::core::{
24 LLVMBuildCall2, LLVMGetArgOperand, LLVMGetBasicBlockParent, LLVMGetGlobalParent,
25 LLVMGetInstructionParent, LLVMGetNamedFunction, LLVMGlobalGetValueType, LLVMIsMultithreaded,
26};
27use llvm_sys::prelude::{LLVMBuilderRef, LLVMValueRef};
28use std::collections::HashMap;
29use std::ffi::CString;
30use std::iter::zip;
31use std::pin::Pin;
32use target_lexicon::Triple;
33
// Scalar type used for model "real" values; matches the f64_type handed to
// CodeGen in LlvmModule::new. NOTE(review): unused in this chunk — confirm
// use sites further down the file.
type RealType = f64;
35
36use crate::ast::{Ast, AstKind};
37use crate::discretise::{DiscreteModel, Tensor, TensorBlock};
38use crate::enzyme::{
39 CConcreteType_DT_Anything, CConcreteType_DT_Double, CConcreteType_DT_Integer,
40 CConcreteType_DT_Pointer, CDerivativeMode_DEM_ForwardMode,
41 CDerivativeMode_DEM_ReverseModeCombined, CFnTypeInfo, CreateEnzymeLogic, CreateTypeAnalysis,
42 DiffeGradientUtils, EnzymeCreateForwardDiff, EnzymeCreatePrimalAndGradient, EnzymeFreeTypeTree,
43 EnzymeGradientUtilsNewFromOriginal, EnzymeLogicRef, EnzymeMergeTypeTree, EnzymeNewTypeTreeCT,
44 EnzymeRegisterCallHandler, EnzymeTypeAnalysisRef, EnzymeTypeTreeOnlyEq, FreeEnzymeLogic,
45 FreeTypeAnalysis, GradientUtils, IntList, LLVMOpaqueContext, CDIFFE_TYPE_DFT_CONSTANT,
46 CDIFFE_TYPE_DFT_DUP_ARG, CDIFFE_TYPE_DFT_DUP_NONEED,
47};
48use crate::execution::compiler::CompilerMode;
49use crate::execution::module::{
50 CodegenModule, CodegenModuleCompile, CodegenModuleEmit, CodegenModuleJit,
51};
52use crate::execution::{DataLayout, Translation, TranslationFrom, TranslationTo};
53use lazy_static::lazy_static;
54use std::sync::Mutex;
55
lazy_static! {
    // Process-wide lock; the i32 payload is unused — only mutual exclusion
    // matters. NOTE(review): lock sites are not visible in this chunk, but
    // this presumably serialises access to Enzyme's non-thread-safe global
    // state — confirm where it is locked.
    static ref my_mutex: Mutex<i32> = Mutex::new(0i32);
}
59
/// Self-referential pair: `codegen` borrows from the boxed `context`, so the
/// struct must never move once constructed — hence `PhantomPinned` here and
/// the `Pin<Box<..>>` wrapper in `LlvmModule`.
struct ImmovableLlvmModule {
    // Really has the lifetime of `context`; the `'static` is a lie upheld by
    // the unsafe transmute in `LlvmModule::new`.
    codegen: Option<CodeGen<'static>>,
    // AliasableBox permits the aliasing implied by the transmuted borrow.
    context: AliasableBox<Context>,
    // Opts out of `Unpin` so the pin actually forbids moves.
    _pin: std::marker::PhantomPinned,
}
68
/// LLVM-backed codegen module: the pinned self-referential context/codegen
/// pair plus the target machine used for optimisation and object emission.
pub struct LlvmModule {
    inner: Pin<Box<ImmovableLlvmModule>>,
    machine: TargetMachine,
}
73
// SAFETY(review): LLVM contexts are not thread-safe in general; these impls
// assert that an LlvmModule is never used from two threads concurrently —
// confirm at the call sites (not visible in this chunk).
unsafe impl Send for LlvmModule {}
unsafe impl Sync for LlvmModule {}
76
impl LlvmModule {
    /// Builds the pinned context/codegen pair and the target machine for
    /// `model`.
    ///
    /// With `triple == None` the host triple, CPU name and CPU features are
    /// used; otherwise the given triple with a "generic" CPU and no features.
    fn new(triple: Option<Triple>, model: &DiscreteModel, threaded: bool) -> Result<Self> {
        let initialization_config = &InitializationConfig::default();
        Target::initialize_all(initialization_config);
        let host_triple = Triple::host();
        let (triple_str, native) = match triple {
            Some(ref triple) => (triple.to_string(), false),
            None => (host_triple.to_string(), true),
        };
        let triple = TargetTriple::create(triple_str.as_str());
        let target = Target::from_triple(&triple).unwrap();
        let cpu = if native {
            TargetMachine::get_host_cpu_name().to_string()
        } else {
            "generic".to_string()
        };
        let features = if native {
            TargetMachine::get_host_cpu_features().to_string()
        } else {
            "".to_string()
        };
        let machine = target
            .create_target_machine(
                &triple,
                cpu.as_str(),
                features.as_str(),
                inkwell::OptimizationLevel::Aggressive,
                inkwell::targets::RelocMode::Default,
                inkwell::targets::CodeModel::Default,
            )
            .unwrap();

        // The codegen will borrow from this boxed context; AliasableBox
        // permits that aliasing, and the pin below keeps its address stable.
        let context = AliasableBox::from_unique(Box::new(Context::create()));
        let mut pinned = Self {
            inner: Box::pin(ImmovableLlvmModule {
                codegen: None,
                context,
                _pin: std::marker::PhantomPinned,
            }),
            machine,
        };

        let context_ref = pinned.inner.context.as_ref();
        let real_type_str = "f64";
        let codegen = CodeGen::new(
            model,
            context_ref,
            context_ref.f64_type(),
            context_ref.i32_type(),
            real_type_str,
            threaded,
        )?;
        // SAFETY: the borrow actually lives exactly as long as
        // `inner.context`, which is heap-pinned and dropped together with
        // `codegen`; the fabricated 'static never escapes this struct.
        let codegen = unsafe { std::mem::transmute::<CodeGen<'_>, CodeGen<'static>>(codegen) };
        // SAFETY: pin projection only mutates a field; nothing is moved.
        unsafe { pinned.inner.as_mut().get_unchecked_mut().codegen = Some(codegen) };
        Ok(pinned)
    }

    /// Runs an O3-like pipeline with loop vectorisation/unrolling disabled.
    /// This runs *before* Enzyme differentiates the module; the full
    /// `default<O3>` pipeline runs afterwards in `post_autodiff_optimisation`.
    fn pre_autodiff_optimisation(&mut self) -> Result<()> {
        let pass_options = PassBuilderOptions::create();
        // Keep loops simple for the autodiff pass; vectorised/unrolled loops
        // are harder to differentiate. NOTE(review): rationale inferred from
        // the pre/post-autodiff split — confirm.
        pass_options.set_loop_vectorization(false);
        pass_options.set_loop_slp_vectorization(false);
        pass_options.set_loop_unrolling(false);
        let passes = "annotation2metadata,forceattrs,inferattrs,coro-early,function<eager-inv>(lower-expect,simplifycfg<bonus-inst-threshold=1;no-forward-switch-cond;no-switch-range-to-icmp;no-switch-to-lookup;keep-loops;no-hoist-common-insts;no-sink-common-insts>,early-cse<>),openmp-opt,ipsccp,called-value-propagation,globalopt,function(mem2reg),function<eager-inv>(instcombine,simplifycfg<bonus-inst-threshold=1;no-forward-switch-cond;switch-range-to-icmp;no-switch-to-lookup;keep-loops;no-hoist-common-insts;no-sink-common-insts>),require<globals-aa>,function(invalidate<aa>),require<profile-summary>,cgscc(devirt<4>(inline<only-mandatory>,inline,function-attrs,openmp-opt-cgscc,function<eager-inv>(early-cse<memssa>,speculative-execution,jump-threading,correlated-propagation,simplifycfg<bonus-inst-threshold=1;no-forward-switch-cond;switch-range-to-icmp;no-switch-to-lookup;keep-loops;no-hoist-common-insts;no-sink-common-insts>,instcombine,libcalls-shrinkwrap,tailcallelim,simplifycfg<bonus-inst-threshold=1;no-forward-switch-cond;switch-range-to-icmp;no-switch-to-lookup;keep-loops;no-hoist-common-insts;no-sink-common-insts>,reassociate,require<opt-remark-emit>,loop-mssa(loop-instsimplify,loop-simplifycfg,licm<no-allowspeculation>,loop-rotate,licm<allowspeculation>,simple-loop-unswitch<no-nontrivial;trivial>),simplifycfg<bonus-inst-threshold=1;no-forward-switch-cond;switch-range-to-icmp;no-switch-to-lookup;keep-loops;no-hoist-common-insts;no-sink-common-insts>,instcombine,loop(loop-idiom,indvars,loop-deletion),vector-combine,mldst-motion<no-split-footer-bb>,gvn<>,sccp,bdce,instcombine,jump-threading,correlated-propagation,adce,memcpyopt,dse,loop-mssa(licm<allowspeculation>),coro-elide,simplifycfg<bonus-inst-threshold=1;no-forward-switch-cond;switch-range-to-icmp;no-switch-to-lookup;keep-loops;hoist-common-insts;sink-common-insts>,instcombine),coro-split)),deadargelim,coro-cleanup,globalopt,globaldce,elim-avail-extern,rpo-function-attrs,recompute-globalsaa,function<eager-inv>(float2int,lower-constant-intrinsics,loop(loop-rotate,loop-deletion),loop-distribute,inject-tli-mappings,loop-load-elim,instcombine,simplifycfg<bonus-inst-threshold=1;forward-switch-cond;switch-range-to-icmp;switch-to-lookup;no-keep-loops;hoist-common-insts;sink-common-insts>,vector-combine,instcombine,transform-warning,instcombine,require<opt-remark-emit>,loop-mssa(licm<allowspeculation>),alignment-from-assumptions,loop-sink,instsimplify,div-rem-pairs,tailcallelim,simplifycfg<bonus-inst-threshold=1;no-forward-switch-cond;switch-range-to-icmp;no-switch-to-lookup;keep-loops;no-hoist-common-insts;no-sink-common-insts>),globaldce,constmerge,cg-profile,rel-lookup-table-converter,function(annotation-remarks),verify";
        let (codegen, machine) = self.codegen_and_machine_mut();
        codegen
            .module()
            .run_passes(passes, machine, pass_options)
            .map_err(|e| anyhow!("Failed to run passes: {:?}", e))
    }

    /// Cleans up after autodiff and runs the standard `default<O3>` pipeline.
    fn post_autodiff_optimisation(&mut self) -> Result<()> {
        // `barrier` is given `noinline` in `compile_barrier`; once
        // differentiation is done the attribute can be dropped so O3 may
        // inline it.
        if let Some(barrier_func) = self.codegen_mut().module().get_function("barrier") {
            let nolinline_kind_id = Attribute::get_named_enum_kind_id("noinline");
            barrier_func.remove_enum_attribute(AttributeLoc::Function, nolinline_kind_id);
        }

        // Delete `preprocess_*` functions (presumably Enzyme preprocessing
        // artifacts — confirm) before the final optimisation pass.
        for f in self.codegen_mut().module.get_functions() {
            if f.get_name().to_str().unwrap().starts_with("preprocess_") {
                unsafe { f.delete() };
            }
        }

        let passes = "default<O3>";
        let (codegen, machine) = self.codegen_and_machine_mut();
        codegen
            .module()
            .run_passes(passes, machine, PassBuilderOptions::create())
            .map_err(|e| anyhow!("Failed to run passes: {:?}", e))?;

        Ok(())
    }

    /// Dumps the module IR to stderr (debugging aid).
    pub fn print(&self) {
        self.codegen().module().print_to_stderr();
    }
    /// Mutable access to the codegen.
    // SAFETY: pin projection is sound — only the `codegen` field is accessed
    // and nothing is moved out of the pinned struct.
    fn codegen_mut(&mut self) -> &mut CodeGen<'static> {
        unsafe {
            self.inner
                .as_mut()
                .get_unchecked_mut()
                .codegen
                .as_mut()
                .unwrap()
        }
    }
    /// Splits the borrow so callers can use the codegen and the target
    /// machine at the same time (needed by `run_passes`).
    fn codegen_and_machine_mut(&mut self) -> (&mut CodeGen<'static>, &TargetMachine) {
        (
            unsafe {
                self.inner
                    .as_mut()
                    .get_unchecked_mut()
                    .codegen
                    .as_mut()
                    .unwrap()
            },
            &self.machine,
        )
    }

    /// Shared access to the codegen (always `Some` after `new` succeeds).
    fn codegen(&self) -> &CodeGen<'static> {
        self.inner.as_ref().get_ref().codegen.as_ref().unwrap()
    }
}
225
// Empty impl: `CodegenModule` requires nothing beyond what the other codegen
// traits provide (no items overridden here).
impl CodegenModule for LlvmModule {}
227
228impl CodegenModuleCompile for LlvmModule {
229 fn from_discrete_model(
230 model: &DiscreteModel,
231 mode: CompilerMode,
232 triple: Option<Triple>,
233 ) -> Result<Self> {
234 let thread_dim = mode.thread_dim(model.state().nnz());
235 let threaded = thread_dim > 1;
236 if (unsafe { LLVMIsMultithreaded() } <= 0) {
237 return Err(anyhow!(
238 "LLVM is not compiled with multithreading support, but this codegen module requires it."
239 ));
240 }
241
242 let mut module = Self::new(triple, model, threaded)?;
243
244 let set_u0 = module.codegen_mut().compile_set_u0(model)?;
245 let _calc_stop = module.codegen_mut().compile_calc_stop(model)?;
246 let rhs = module.codegen_mut().compile_rhs(model, false)?;
247 let rhs_full = module.codegen_mut().compile_rhs(model, true)?;
248 let mass = module.codegen_mut().compile_mass(model)?;
249 let calc_out = module.codegen_mut().compile_calc_out(model, false)?;
250 let calc_out_full = module.codegen_mut().compile_calc_out(model, true)?;
251 let _set_id = module.codegen_mut().compile_set_id(model)?;
252 let _get_dims = module.codegen_mut().compile_get_dims(model)?;
253 let set_inputs = module.codegen_mut().compile_inputs(model, false)?;
254 let _get_inputs = module.codegen_mut().compile_inputs(model, true)?;
255 let _set_constants = module.codegen_mut().compile_set_constants(model)?;
256 let tensor_info = module
257 .codegen()
258 .layout
259 .tensors()
260 .map(|(name, is_constant)| (name.to_string(), is_constant))
261 .collect::<Vec<_>>();
262 for (tensor, is_constant) in tensor_info {
263 if is_constant {
264 module
265 .codegen_mut()
266 .compile_get_constant(model, tensor.as_str())?;
267 } else {
268 module
269 .codegen_mut()
270 .compile_get_tensor(model, tensor.as_str())?;
271 }
272 }
273
274 module.pre_autodiff_optimisation()?;
275
276 module.codegen_mut().compile_gradient(
277 set_u0,
278 &[
279 CompileGradientArgType::DupNoNeed,
280 CompileGradientArgType::DupNoNeed,
281 CompileGradientArgType::Const,
282 CompileGradientArgType::Const,
283 ],
284 CompileMode::Forward,
285 "set_u0_grad",
286 )?;
287
288 module.codegen_mut().compile_gradient(
289 rhs,
290 &[
291 CompileGradientArgType::Const,
292 CompileGradientArgType::DupNoNeed,
293 CompileGradientArgType::DupNoNeed,
294 CompileGradientArgType::DupNoNeed,
295 CompileGradientArgType::Const,
296 CompileGradientArgType::Const,
297 ],
298 CompileMode::Forward,
299 "rhs_grad",
300 )?;
301
302 module.codegen_mut().compile_gradient(
303 calc_out,
304 &[
305 CompileGradientArgType::Const,
306 CompileGradientArgType::DupNoNeed,
307 CompileGradientArgType::DupNoNeed,
308 CompileGradientArgType::DupNoNeed,
309 CompileGradientArgType::Const,
310 CompileGradientArgType::Const,
311 ],
312 CompileMode::Forward,
313 "calc_out_grad",
314 )?;
315 module.codegen_mut().compile_gradient(
316 set_inputs,
317 &[
318 CompileGradientArgType::DupNoNeed,
319 CompileGradientArgType::DupNoNeed,
320 ],
321 CompileMode::Forward,
322 "set_inputs_grad",
323 )?;
324
325 module.codegen_mut().compile_gradient(
326 set_u0,
327 &[
328 CompileGradientArgType::DupNoNeed,
329 CompileGradientArgType::DupNoNeed,
330 CompileGradientArgType::Const,
331 CompileGradientArgType::Const,
332 ],
333 CompileMode::Reverse,
334 "set_u0_rgrad",
335 )?;
336
337 module.codegen_mut().compile_gradient(
338 mass,
339 &[
340 CompileGradientArgType::Const,
341 CompileGradientArgType::DupNoNeed,
342 CompileGradientArgType::DupNoNeed,
343 CompileGradientArgType::DupNoNeed,
344 CompileGradientArgType::Const,
345 CompileGradientArgType::Const,
346 ],
347 CompileMode::Reverse,
348 "mass_rgrad",
349 )?;
350
351 module.codegen_mut().compile_gradient(
352 rhs,
353 &[
354 CompileGradientArgType::Const,
355 CompileGradientArgType::DupNoNeed,
356 CompileGradientArgType::DupNoNeed,
357 CompileGradientArgType::DupNoNeed,
358 CompileGradientArgType::Const,
359 CompileGradientArgType::Const,
360 ],
361 CompileMode::Reverse,
362 "rhs_rgrad",
363 )?;
364 module.codegen_mut().compile_gradient(
365 calc_out,
366 &[
367 CompileGradientArgType::Const,
368 CompileGradientArgType::DupNoNeed,
369 CompileGradientArgType::DupNoNeed,
370 CompileGradientArgType::DupNoNeed,
371 CompileGradientArgType::Const,
372 CompileGradientArgType::Const,
373 ],
374 CompileMode::Reverse,
375 "calc_out_rgrad",
376 )?;
377
378 module.codegen_mut().compile_gradient(
379 set_inputs,
380 &[
381 CompileGradientArgType::DupNoNeed,
382 CompileGradientArgType::DupNoNeed,
383 ],
384 CompileMode::Reverse,
385 "set_inputs_rgrad",
386 )?;
387
388 module.codegen_mut().compile_gradient(
389 rhs_full,
390 &[
391 CompileGradientArgType::Const,
392 CompileGradientArgType::Const,
393 CompileGradientArgType::DupNoNeed,
394 CompileGradientArgType::DupNoNeed,
395 CompileGradientArgType::Const,
396 CompileGradientArgType::Const,
397 ],
398 CompileMode::ForwardSens,
399 "rhs_sgrad",
400 )?;
401
402 module.codegen_mut().compile_gradient(
403 set_u0,
404 &[
405 CompileGradientArgType::DupNoNeed,
406 CompileGradientArgType::DupNoNeed,
407 CompileGradientArgType::Const,
408 CompileGradientArgType::Const,
409 ],
410 CompileMode::ForwardSens,
411 "set_u0_sgrad",
412 )?;
413
414 module.codegen_mut().compile_gradient(
415 calc_out_full,
416 &[
417 CompileGradientArgType::Const,
418 CompileGradientArgType::Const,
419 CompileGradientArgType::DupNoNeed,
420 CompileGradientArgType::DupNoNeed,
421 CompileGradientArgType::Const,
422 CompileGradientArgType::Const,
423 ],
424 CompileMode::ForwardSens,
425 "calc_out_sgrad",
426 )?;
427 module.codegen_mut().compile_gradient(
428 calc_out_full,
429 &[
430 CompileGradientArgType::Const,
431 CompileGradientArgType::Const,
432 CompileGradientArgType::DupNoNeed,
433 CompileGradientArgType::DupNoNeed,
434 CompileGradientArgType::Const,
435 CompileGradientArgType::Const,
436 ],
437 CompileMode::ReverseSens,
438 "calc_out_srgrad",
439 )?;
440
441 module.codegen_mut().compile_gradient(
442 rhs_full,
443 &[
444 CompileGradientArgType::Const,
445 CompileGradientArgType::Const,
446 CompileGradientArgType::DupNoNeed,
447 CompileGradientArgType::DupNoNeed,
448 CompileGradientArgType::Const,
449 CompileGradientArgType::Const,
450 ],
451 CompileMode::ReverseSens,
452 "rhs_srgrad",
453 )?;
454
455 module.post_autodiff_optimisation()?;
456 Ok(module)
457 }
458}
459
460impl CodegenModuleEmit for LlvmModule {
461 fn to_object(self) -> Result<Vec<u8>> {
462 let module = self.codegen().module();
463 let buffer = self
465 .machine
466 .write_to_memory_buffer(module, FileType::Object)
467 .unwrap()
468 .as_slice()
469 .to_vec();
470 Ok(buffer)
471 }
472}
473impl CodegenModuleJit for LlvmModule {
474 fn jit(&mut self) -> Result<HashMap<String, *const u8>> {
475 let ee = self
476 .codegen()
477 .module()
478 .create_jit_execution_engine(OptimizationLevel::Aggressive)
479 .map_err(|e| anyhow!("Failed to create JIT execution engine: {:?}", e))?;
480
481 let module = self.codegen().module();
482 let mut symbols = HashMap::new();
483 for function in module.get_functions() {
484 let name = function.get_name().to_str().unwrap();
485 let address = ee.get_function_address(name);
486 if let Ok(address) = address {
487 symbols.insert(name.to_string(), address as *const u8);
488 }
489 }
490 Ok(symbols)
491 }
492}
493
/// Module-level LLVM globals shared by all generated functions.
struct Globals<'ctx> {
    // Constant index table ("enzyme_const_indices"); None when layout has none.
    indices: Option<GlobalValue<'ctx>>,
    // Mutable constants array ("enzyme_const_constants"); None when empty.
    constants: Option<GlobalValue<'ctx>>,
    // Spin-barrier counter ("enzyme_const_thread_counter"); threaded builds only.
    thread_counter: Option<GlobalValue<'ctx>>,
}
499
500impl<'ctx> Globals<'ctx> {
501 fn new(
502 layout: &DataLayout,
503 module: &Module<'ctx>,
504 int_type: IntType<'ctx>,
505 real_type: FloatType<'ctx>,
506 threaded: bool,
507 ) -> Self {
508 let thread_counter = if threaded {
509 let tc = module.add_global(
510 int_type,
511 Some(AddressSpace::default()),
512 "enzyme_const_thread_counter",
513 );
514 let tc_value = int_type.const_zero();
521 tc.set_visibility(GlobalVisibility::Hidden);
522 tc.set_initializer(&tc_value.as_basic_value_enum());
523 Some(tc)
524 } else {
525 None
526 };
527 let constants = if layout.constants().is_empty() {
528 None
529 } else {
530 let constants_array_type =
531 real_type.array_type(u32::try_from(layout.constants().len()).unwrap());
532 let constants = module.add_global(
533 constants_array_type,
534 Some(AddressSpace::default()),
535 "enzyme_const_constants",
536 );
537 constants.set_visibility(GlobalVisibility::Hidden);
538 constants.set_constant(false);
539 constants.set_initializer(&constants_array_type.const_zero());
540 Some(constants)
541 };
542 let indices = if layout.indices().is_empty() {
543 None
544 } else {
545 let indices_array_type =
546 int_type.array_type(u32::try_from(layout.indices().len()).unwrap());
547 let indices_array_values = layout
548 .indices()
549 .iter()
550 .map(|&i| int_type.const_int(i.try_into().unwrap(), false))
551 .collect::<Vec<IntValue>>();
552 let indices_value = int_type.const_array(indices_array_values.as_slice());
553 let indices = module.add_global(
554 indices_array_type,
555 Some(AddressSpace::default()),
556 "enzyme_const_indices",
557 );
558 indices.set_constant(true);
559 indices.set_visibility(GlobalVisibility::Hidden);
560 indices.set_initializer(&indices_value);
561 Some(indices)
562 };
563 Self {
564 indices,
565 thread_counter,
566 constants,
567 }
568 }
569}
570
/// Activity of one function argument when generating a derivative.
pub enum CompileGradientArgType {
    /// Treated as constant: not differentiated.
    Const,
    /// Duplicated: primal value plus a shadow (derivative) value.
    Dup,
    /// Duplicated shadow whose primal result is not needed.
    DupNoNeed,
}
576
/// Which differentiation variant `compile_gradient` emits (selects the
/// `*_grad` / `*_sgrad` / `*_rgrad` / `*_srgrad` family of functions).
pub enum CompileMode {
    /// Forward-mode derivative.
    Forward,
    /// Forward-mode sensitivity variant.
    ForwardSens,
    /// Reverse-mode (adjoint) derivative.
    Reverse,
    /// Reverse-mode sensitivity variant.
    ReverseSens,
}
583
/// Per-model LLVM code generator: owns the module and IR builder plus the
/// bookkeeping needed while emitting each function.
pub struct CodeGen<'ctx> {
    context: &'ctx inkwell::context::Context,
    module: Module<'ctx>,
    builder: Builder<'ctx>,
    // Name -> stack slot (alloca) for function args and tensor storage.
    variables: HashMap<String, PointerValue<'ctx>>,
    // Compiled support functions (e.g. "barrier"), looked up by name.
    functions: HashMap<String, FunctionValue<'ctx>>,
    // Function currently being emitted, if any.
    fn_value_opt: Option<FunctionValue<'ctx>>,
    // Pointer for the tensor currently being compiled — assumption, confirm
    // at the jit_compile_tensor use sites (not fully visible here).
    tensor_ptr_opt: Option<PointerValue<'ctx>>,
    real_type: FloatType<'ctx>,
    real_ptr_type: PointerType<'ctx>,
    // Textual name of the real type (e.g. "f64"); presumably used to build
    // intrinsic names — confirm at use sites.
    real_type_str: String,
    int_type: IntType<'ctx>,
    int_ptr_type: PointerType<'ctx>,
    layout: DataLayout,
    globals: Globals<'ctx>,
    threaded: bool,
    // Keeps a JIT execution engine alive when one has been created.
    _ee: Option<ExecutionEngine<'ctx>>,
}
602
/// Forward-mode custom call handler for `barrier` (registered with Enzyme via
/// `EnzymeRegisterCallHandler` — registration site not visible in this chunk).
///
/// Returns 1 without emitting anything; per Enzyme's call-handler convention
/// this presumably means "handled, no derivative code needed" — confirm
/// against the Enzyme custom-handler documentation.
unsafe extern "C" fn fwd_handler(
    _builder: LLVMBuilderRef,
    _call_instruction: LLVMValueRef,
    _gutils: *mut GradientUtils,
    _dcall: *mut LLVMValueRef,
    _normal_return: *mut LLVMValueRef,
    _shadow_return: *mut LLVMValueRef,
) -> u8 {
    1
}
613
/// Reverse-mode custom call handler for `barrier`: emits a call to
/// `barrier_grad` with the (remapped) original three arguments so the reverse
/// pass stays synchronised across threads.
///
/// # Safety
/// Must be called by Enzyme with a valid builder and a call instruction that
/// lives inside a module containing a `barrier_grad` function taking the same
/// three integer arguments (created in `compile_barrier_grad`).
unsafe extern "C" fn rev_handler(
    builder: LLVMBuilderRef,
    call_instruction: LLVMValueRef,
    gutils: *mut DiffeGradientUtils,
    _tape: LLVMValueRef,
) {
    // Walk instruction -> basic block -> function -> module to find
    // `barrier_grad` in the module being differentiated.
    let call_block = LLVMGetInstructionParent(call_instruction);
    let call_function = LLVMGetBasicBlockParent(call_block);
    let module = LLVMGetGlobalParent(call_function);
    let name_c_str = CString::new("barrier_grad").unwrap();
    let barrier_func = LLVMGetNamedFunction(module, name_c_str.as_ptr());
    let barrier_func_type = LLVMGlobalGetValueType(barrier_func);
    // Original barrier call arguments: (barrier_num, total_barriers,
    // thread_count).
    let barrier_num = LLVMGetArgOperand(call_instruction, 0);
    let total_barriers = LLVMGetArgOperand(call_instruction, 1);
    let thread_count = LLVMGetArgOperand(call_instruction, 2);
    // Remap each primal value into the gradient function under construction.
    let barrier_num = EnzymeGradientUtilsNewFromOriginal(gutils as *mut GradientUtils, barrier_num);
    let total_barriers =
        EnzymeGradientUtilsNewFromOriginal(gutils as *mut GradientUtils, total_barriers);
    let thread_count =
        EnzymeGradientUtilsNewFromOriginal(gutils as *mut GradientUtils, thread_count);
    let mut args = [barrier_num, total_barriers, thread_count];
    // Empty value name: the call returns void.
    let name_c_str = CString::new("").unwrap();
    LLVMBuildCall2(
        builder,
        barrier_func_type,
        barrier_func,
        args.as_mut_ptr(),
        args.len() as u32,
        name_c_str.as_ptr(),
    );
}
645
/// Value wrapper for the debug-printf helper (`compile_print_value`).
#[allow(dead_code)]
enum PrintValue<'ctx> {
    // Printed with "%f".
    Real(FloatValue<'ctx>),
    // Printed with "%d".
    Int(IntValue<'ctx>),
}
651
652impl<'ctx> CodeGen<'ctx> {
653 pub fn new(
654 model: &DiscreteModel,
655 context: &'ctx inkwell::context::Context,
656 real_type: FloatType<'ctx>,
657 int_type: IntType<'ctx>,
658 real_type_str: &str,
659 threaded: bool,
660 ) -> Result<Self> {
661 let builder = context.create_builder();
662 let layout = DataLayout::new(model);
663 let module = context.create_module(model.name());
664 let globals = Globals::new(&layout, &module, int_type, real_type, threaded);
665 let real_ptr_type = Self::pointer_type(context, real_type.into());
666 let int_ptr_type = Self::pointer_type(context, int_type.into());
667 let mut ret = Self {
668 context,
669 module,
670 builder,
671 real_type,
672 real_ptr_type,
673 real_type_str: real_type_str.to_owned(),
674 variables: HashMap::new(),
675 functions: HashMap::new(),
676 fn_value_opt: None,
677 tensor_ptr_opt: None,
678 layout,
679 int_type,
680 int_ptr_type,
681 globals,
682 threaded,
683 _ee: None,
684 };
685 if threaded {
686 ret.compile_barrier_init()?;
687 ret.compile_barrier()?;
688 ret.compile_barrier_grad()?;
689 }
692 Ok(ret)
693 }
694
695 #[allow(dead_code)]
696 fn compile_print_value(
697 &mut self,
698 name: &str,
699 value: PrintValue<'ctx>,
700 ) -> Result<CallSiteValue<'_>> {
701 let void_type = self.context.void_type();
702 let printf_type = void_type.fn_type(&[self.int_ptr_type.into()], true);
704 let printf = match self.module.get_function("printf") {
706 Some(f) => f,
707 None => self
708 .module
709 .add_function("printf", printf_type, Some(Linkage::External)),
710 };
711 let (format_str, format_str_name) = match value {
712 PrintValue::Real(_) => (format!("{name}: %f\n"), format!("real_format_{name}")),
713 PrintValue::Int(_) => (format!("{name}: %d\n"), format!("int_format_{name}")),
714 };
715 let format_str = CString::new(format_str).unwrap();
717 let format_str_global = match self.module.get_global(format_str_name.as_str()) {
719 Some(g) => g,
720 None => {
721 let format_str = self.context.const_string(format_str.as_bytes(), true);
722 let fmt_str =
723 self.module
724 .add_global(format_str.get_type(), None, format_str_name.as_str());
725 fmt_str.set_initializer(&format_str);
726 fmt_str.set_visibility(GlobalVisibility::Hidden);
727 fmt_str
728 }
729 };
730 let format_str_ptr = self.builder.build_pointer_cast(
732 format_str_global.as_pointer_value(),
733 self.int_ptr_type,
734 "format_str_ptr",
735 )?;
736 let value: BasicMetadataValueEnum = match value {
737 PrintValue::Real(v) => v.into(),
738 PrintValue::Int(v) => v.into(),
739 };
740 self.builder
741 .build_call(printf, &[format_str_ptr.into(), value], "printf_call")
742 .map_err(|e| anyhow!("Error building call to printf: {}", e))
743 }
744
745 fn compile_set_constants(&mut self, model: &DiscreteModel) -> Result<FunctionValue<'ctx>> {
746 self.clear();
747 let void_type = self.context.void_type();
748 let fn_type = void_type.fn_type(&[self.int_type.into(), self.int_type.into()], false);
749 let fn_arg_names = &["thread_id", "thread_dim"];
750 let function = self.module.add_function("set_constants", fn_type, None);
751
752 let basic_block = self.context.append_basic_block(function, "entry");
753 self.fn_value_opt = Some(function);
754 self.builder.position_at_end(basic_block);
755
756 for (i, arg) in function.get_param_iter().enumerate() {
757 let name = fn_arg_names[i];
758 let alloca = self.function_arg_alloca(name, arg);
759 self.insert_param(name, alloca);
760 }
761
762 self.insert_indices();
763 self.insert_constants(model);
764
765 let mut nbarriers = 0;
766 let total_barriers = (model.constant_defns().len()) as u64;
767 #[allow(clippy::explicit_counter_loop)]
768 for a in model.constant_defns() {
769 self.jit_compile_tensor(a, Some(*self.get_var(a)))?;
770 self.jit_compile_call_barrier(nbarriers, total_barriers);
771 nbarriers += 1;
772 }
773
774 self.builder.build_return(None)?;
775
776 if function.verify(true) {
777 Ok(function)
778 } else {
779 function.print_to_stderr();
780 unsafe {
781 function.delete();
782 }
783 Err(anyhow!("Invalid generated function."))
784 }
785 }
786
787 fn compile_barrier_init(&mut self) -> Result<FunctionValue<'ctx>> {
788 self.clear();
789 let void_type = self.context.void_type();
790 let fn_type = void_type.fn_type(&[], false);
791 let function = self.module.add_function("barrier_init", fn_type, None);
792
793 let entry_block = self.context.append_basic_block(function, "entry");
794
795 self.fn_value_opt = Some(function);
796 self.builder.position_at_end(entry_block);
797
798 let thread_counter = self.globals.thread_counter.unwrap().as_pointer_value();
799 self.builder
800 .build_store(thread_counter, self.int_type.const_zero())?;
801
802 self.builder.build_return(None)?;
803
804 if function.verify(true) {
805 self.functions.insert("barrier_init".to_owned(), function);
806 Ok(function)
807 } else {
808 function.print_to_stderr();
809 unsafe {
810 function.delete();
811 }
812 Err(anyhow!("Invalid generated function."))
813 }
814 }
815
    /// Compiles `void barrier(int barrier_num, int total_barriers,
    /// int thread_count)`: a spin barrier over the shared thread counter.
    ///
    /// Marked `noinline` so barrier calls stay recognisable through the
    /// autodiff stage; the attribute is stripped again in
    /// `post_autodiff_optimisation`.
    fn compile_barrier(&mut self) -> Result<FunctionValue<'ctx>> {
        self.clear();
        let void_type = self.context.void_type();
        let fn_type = void_type.fn_type(
            &[
                self.int_type.into(),
                self.int_type.into(),
                self.int_type.into(),
            ],
            false,
        );
        let function = self.module.add_function("barrier", fn_type, None);
        let nolinline_kind_id = Attribute::get_named_enum_kind_id("noinline");
        let noinline = self.context.create_enum_attribute(nolinline_kind_id, 0);
        function.add_attribute(AttributeLoc::Function, noinline);

        let entry_block = self.context.append_basic_block(function, "entry");
        let increment_block = self.context.append_basic_block(function, "increment");
        let wait_loop_block = self.context.append_basic_block(function, "wait_loop");
        let barrier_done_block = self.context.append_basic_block(function, "barrier_done");

        self.fn_value_opt = Some(function);
        self.builder.position_at_end(entry_block);

        let thread_counter = self.globals.thread_counter.unwrap().as_pointer_value();
        let barrier_num = function.get_nth_param(0).unwrap().into_int_value();
        let total_barriers = function.get_nth_param(1).unwrap().into_int_value();
        let thread_count = function.get_nth_param(2).unwrap().into_int_value();

        // When barrier_num == total_barriers the barrier is skipped entirely.
        let nbarrier_equals_total_barriers = self
            .builder
            .build_int_compare(
                IntPredicate::EQ,
                barrier_num,
                total_barriers,
                "nbarrier_equals_total_barriers",
            )
            .unwrap();
        self.builder.build_conditional_branch(
            nbarrier_equals_total_barriers,
            barrier_done_block,
            increment_block,
        )?;
        self.builder.position_at_end(increment_block);

        // Counter value every thread must observe before proceeding past
        // barrier number `barrier_num`.
        let barrier_num_times_thread_count = self
            .builder
            .build_int_mul(barrier_num, thread_count, "barrier_num_times_thread_count")
            .unwrap();

        // Atomically announce this thread's arrival.
        // NOTE(review): hard-codes i32 rather than self.int_type; fine while
        // int_type is i32 (as set up in LlvmModule::new) — confirm if that
        // ever changes.
        let i32_type = self.context.i32_type();
        let one = i32_type.const_int(1, false);
        self.builder.build_atomicrmw(
            AtomicRMWBinOp::Add,
            thread_counter,
            one,
            AtomicOrdering::Monotonic,
        )?;

        self.builder.build_unconditional_branch(wait_loop_block)?;
        self.builder.position_at_end(wait_loop_block);

        // Spin: atomically re-load the counter until all threads have arrived.
        let current_value = self
            .builder
            .build_load(i32_type, thread_counter, "current_value")?
            .into_int_value();

        current_value
            .as_instruction_value()
            .unwrap()
            .set_atomic_ordering(AtomicOrdering::Monotonic)
            .map_err(|e| anyhow!("Error setting atomic ordering: {:?}", e))?;

        let all_threads_done = self.builder.build_int_compare(
            IntPredicate::UGE,
            current_value,
            barrier_num_times_thread_count,
            "all_threads_done",
        )?;

        self.builder.build_conditional_branch(
            all_threads_done,
            barrier_done_block,
            wait_loop_block,
        )?;
        self.builder.position_at_end(barrier_done_block);

        self.builder.build_return(None)?;

        if function.verify(true) {
            self.functions.insert("barrier".to_owned(), function);
            Ok(function)
        } else {
            function.print_to_stderr();
            unsafe {
                function.delete();
            }
            Err(anyhow!("Invalid generated function."))
        }
    }
919
    /// Compiles `void barrier_grad(int barrier_num, int total_barriers,
    /// int thread_count)`: the barrier used during the reverse (gradient)
    /// pass, emitted by `rev_handler` in place of each primal barrier call.
    ///
    /// The shared counter keeps counting up from the forward pass; since the
    /// reverse pass visits barriers in reverse order, the wait threshold is
    /// (2 * total_barriers - barrier_num) * thread_count.
    fn compile_barrier_grad(&mut self) -> Result<FunctionValue<'ctx>> {
        self.clear();
        let void_type = self.context.void_type();
        let fn_type = void_type.fn_type(
            &[
                self.int_type.into(),
                self.int_type.into(),
                self.int_type.into(),
            ],
            false,
        );
        let function = self.module.add_function("barrier_grad", fn_type, None);

        let entry_block = self.context.append_basic_block(function, "entry");
        let wait_loop_block = self.context.append_basic_block(function, "wait_loop");
        let barrier_done_block = self.context.append_basic_block(function, "barrier_done");

        self.fn_value_opt = Some(function);
        self.builder.position_at_end(entry_block);

        let thread_counter = self.globals.thread_counter.unwrap().as_pointer_value();
        let barrier_num = function.get_nth_param(0).unwrap().into_int_value();
        let total_barriers = function.get_nth_param(1).unwrap().into_int_value();
        let thread_count = function.get_nth_param(2).unwrap().into_int_value();

        // threshold = (2 * total_barriers - barrier_num) * thread_count
        let twice_total_barriers = self
            .builder
            .build_int_mul(
                total_barriers,
                self.int_type.const_int(2, false),
                "twice_total_barriers",
            )
            .unwrap();
        let twice_total_barriers_minus_barrier_num = self
            .builder
            .build_int_sub(
                twice_total_barriers,
                barrier_num,
                "twice_total_barriers_minus_barrier_num",
            )
            .unwrap();
        let twice_total_barriers_minus_barrier_num_times_thread_count = self
            .builder
            .build_int_mul(
                twice_total_barriers_minus_barrier_num,
                thread_count,
                "twice_total_barriers_minus_barrier_num_times_thread_count",
            )
            .unwrap();

        // Atomically announce arrival (same i32 caveat as compile_barrier).
        let i32_type = self.context.i32_type();
        let one = i32_type.const_int(1, false);
        self.builder.build_atomicrmw(
            AtomicRMWBinOp::Add,
            thread_counter,
            one,
            AtomicOrdering::Monotonic,
        )?;

        self.builder.build_unconditional_branch(wait_loop_block)?;
        self.builder.position_at_end(wait_loop_block);

        // Spin until the counter reaches the reverse-pass threshold.
        let current_value = self
            .builder
            .build_load(i32_type, thread_counter, "current_value")?
            .into_int_value();
        current_value
            .as_instruction_value()
            .unwrap()
            .set_atomic_ordering(AtomicOrdering::Monotonic)
            .map_err(|e| anyhow!("Error setting atomic ordering: {:?}", e))?;

        let all_threads_done = self.builder.build_int_compare(
            IntPredicate::UGE,
            current_value,
            twice_total_barriers_minus_barrier_num_times_thread_count,
            "all_threads_done",
        )?;

        self.builder.build_conditional_branch(
            all_threads_done,
            barrier_done_block,
            wait_loop_block,
        )?;
        self.builder.position_at_end(barrier_done_block);

        self.builder.build_return(None)?;

        if function.verify(true) {
            self.functions.insert("barrier_grad".to_owned(), function);
            Ok(function)
        } else {
            function.print_to_stderr();
            unsafe {
                function.delete();
            }
            Err(anyhow!("Invalid generated function."))
        }
    }
1021
1022 fn jit_compile_call_barrier(&mut self, nbarrier: u64, total_barriers: u64) {
1023 if !self.threaded {
1024 return;
1025 }
1026 let thread_dim = self.get_param("thread_dim");
1027 let thread_dim = self
1028 .builder
1029 .build_load(self.int_type, *thread_dim, "thread_dim")
1030 .unwrap()
1031 .into_int_value();
1032 let nbarrier = self.int_type.const_int(nbarrier + 1, false);
1033 let total_barriers = self.int_type.const_int(total_barriers, false);
1034 let barrier = self.get_function("barrier").unwrap();
1035 self.builder
1036 .build_call(
1037 barrier,
1038 &[
1039 BasicMetadataValueEnum::IntValue(nbarrier),
1040 BasicMetadataValueEnum::IntValue(total_barriers),
1041 BasicMetadataValueEnum::IntValue(thread_dim),
1042 ],
1043 "barrier",
1044 )
1045 .unwrap();
1046 }
1047
    /// Compute this thread's chunk `[start, end)` of a loop of length `size`,
    /// partitioning the iteration space evenly over `thread_dim` threads as
    /// `start = thread_id * size / thread_dim`,
    /// `end = (thread_id + 1) * size / thread_dim`.
    ///
    /// Returns `(start, end, test_done, next_block)`: `test_done` is the block
    /// the caller was emitting into (later used by `jit_end_threading` to
    /// branch around empty chunks) and `next_block` is a fresh block that the
    /// builder is left positioned at for the loop body.
    fn jit_threading_limits(
        &mut self,
        size: IntValue<'ctx>,
    ) -> Result<(
        IntValue<'ctx>,
        IntValue<'ctx>,
        BasicBlock<'ctx>,
        BasicBlock<'ctx>,
    )> {
        let one = self.int_type.const_int(1, false);
        let thread_id = self.get_param("thread_id");
        let thread_id = self
            .builder
            .build_load(self.int_type, *thread_id, "thread_id")
            .unwrap()
            .into_int_value();
        let thread_dim = self.get_param("thread_dim");
        let thread_dim = self
            .builder
            .build_load(self.int_type, *thread_dim, "thread_dim")
            .unwrap()
            .into_int_value();

        // start = thread_id * size / thread_dim (unsigned arithmetic).
        let i_times_size = self
            .builder
            .build_int_mul(thread_id, size, "i_times_size")?;
        let start = self
            .builder
            .build_int_unsigned_div(i_times_size, thread_dim, "start")?;

        // end = (thread_id + 1) * size / thread_dim.
        let i_plus_one = self.builder.build_int_add(thread_id, one, "i_plus_one")?;
        let i_plus_one_times_size =
            self.builder
                .build_int_mul(i_plus_one, size, "i_plus_one_times_size")?;
        let end = self
            .builder
            .build_int_unsigned_div(i_plus_one_times_size, thread_dim, "end")?;

        // Remember where we were; jit_end_threading emits the `start == end`
        // test there and branches either into `next_block` or past the loop.
        let test_done = self.builder.get_insert_block().unwrap();
        let next_block = self
            .context
            .append_basic_block(self.fn_value_opt.unwrap(), "threading_block");
        self.builder.position_at_end(next_block);

        Ok((start, end, test_done, next_block))
    }
1096
    /// Close a threaded region opened by `jit_threading_limits`: appends an
    /// `exit` block, then back-fills `test_done` with a `start == end` check
    /// so threads with an empty chunk skip the loop body entirely. Leaves the
    /// builder positioned at `exit`.
    fn jit_end_threading(
        &mut self,
        start: IntValue<'ctx>,
        end: IntValue<'ctx>,
        test_done: BasicBlock<'ctx>,
        next: BasicBlock<'ctx>,
    ) -> Result<()> {
        let exit = self
            .context
            .append_basic_block(self.fn_value_opt.unwrap(), "exit");
        // Terminate the loop's fall-through block by jumping to exit.
        self.builder.build_unconditional_branch(exit)?;
        // Go back and finish the block left open by jit_threading_limits.
        self.builder.position_at_end(test_done);
        let done = self
            .builder
            .build_int_compare(IntPredicate::EQ, start, end, "done")?;
        // Empty chunk (start == end) -> skip straight to exit.
        self.builder.build_conditional_branch(done, exit, next)?;
        self.builder.position_at_end(exit);
        Ok(())
    }
1117
    /// Serialise the compiled LLVM module as bitcode to `path`.
    pub fn write_bitcode_to_path(&self, path: &std::path::Path) {
        self.module.write_bitcode_to_path(path);
    }
1121
1122 fn insert_constants(&mut self, model: &DiscreteModel) {
1123 if let Some(constants) = self.globals.constants.as_ref() {
1124 self.insert_param("constants", constants.as_pointer_value());
1125 for tensor in model.constant_defns() {
1126 self.insert_tensor(tensor, true);
1127 }
1128 }
1129 }
1130
1131 fn insert_data(&mut self, model: &DiscreteModel) {
1132 self.insert_constants(model);
1133
1134 for tensor in model.inputs() {
1135 self.insert_tensor(tensor, false);
1136 }
1137 for tensor in model.input_dep_defns() {
1138 self.insert_tensor(tensor, false);
1139 }
1140 for tensor in model.time_dep_defns() {
1141 self.insert_tensor(tensor, false);
1142 }
1143 for tensor in model.state_dep_defns() {
1144 self.insert_tensor(tensor, false);
1145 }
1146 }
1147
    /// Return the context's (opaque) pointer type; `_ty` is unused because
    /// modern LLVM pointers carry no pointee type.
    fn pointer_type(context: &'ctx Context, _ty: BasicTypeEnum<'ctx>) -> PointerType<'ctx> {
        context.ptr_type(AddressSpace::default())
    }
1151
    /// Return the (opaque) pointer type used for function pointers; `_ty` is
    /// unused for the same reason as in `pointer_type`.
    fn fn_pointer_type(context: &'ctx Context, _ty: FunctionType<'ctx>) -> PointerType<'ctx> {
        context.ptr_type(AddressSpace::default())
    }
1155
1156 fn insert_indices(&mut self) {
1157 if let Some(indices) = self.globals.indices.as_ref() {
1158 let i32_type = self.context.i32_type();
1159 let zero = i32_type.const_int(0, false);
1160 let ptr = unsafe {
1161 indices
1162 .as_pointer_value()
1163 .const_in_bounds_gep(i32_type, &[zero])
1164 };
1165 self.variables.insert("indices".to_owned(), ptr);
1166 }
1167 }
1168
    /// Register a named pointer (e.g. a function parameter) in the variable
    /// table; overwrites any previous entry with the same name.
    fn insert_param(&mut self, name: &str, value: PointerValue<'ctx>) {
        self.variables.insert(name.to_owned(), value);
    }
1172
1173 fn build_gep<T: BasicType<'ctx>>(
1174 &self,
1175 ty: T,
1176 ptr: PointerValue<'ctx>,
1177 ordered_indexes: &[IntValue<'ctx>],
1178 name: &str,
1179 ) -> Result<PointerValue<'ctx>> {
1180 unsafe {
1181 self.builder
1182 .build_gep(ty, ptr, ordered_indexes, name)
1183 .map_err(|e| e.into())
1184 }
1185 }
1186
1187 fn build_load<T: BasicType<'ctx>>(
1188 &self,
1189 ty: T,
1190 ptr: PointerValue<'ctx>,
1191 name: &str,
1192 ) -> Result<BasicValueEnum<'ctx>> {
1193 self.builder.build_load(ty, ptr, name).map_err(|e| e.into())
1194 }
1195
    /// Build an in-bounds GEP computing `&ptr[index]` using the supplied
    /// builder (which may be an entry-block builder rather than `self.builder`).
    fn get_ptr_to_index<T: BasicType<'ctx>>(
        builder: &Builder<'ctx>,
        ty: T,
        ptr: &PointerValue<'ctx>,
        index: IntValue<'ctx>,
        name: &str,
    ) -> PointerValue<'ctx> {
        // SAFETY: callers derive `index` from the data layout, so it is in
        // bounds for the allocation behind `ptr`.
        unsafe {
            builder
                .build_in_bounds_gep(ty, *ptr, &[index], name)
                .unwrap()
        }
    }
1209
1210 fn insert_state(&mut self, u: &Tensor) {
1211 let mut data_index = 0;
1212 for blk in u.elmts() {
1213 if let Some(name) = blk.name() {
1214 let ptr = self.variables.get("u").unwrap();
1215 let i = self
1216 .context
1217 .i32_type()
1218 .const_int(data_index.try_into().unwrap(), false);
1219 let alloca = Self::get_ptr_to_index(
1220 &self.create_entry_block_builder(),
1221 self.real_type,
1222 ptr,
1223 i,
1224 blk.name().unwrap(),
1225 );
1226 self.variables.insert(name.to_owned(), alloca);
1227 }
1228 data_index += blk.nnz();
1229 }
1230 }
1231 fn insert_dot_state(&mut self, dudt: &Tensor) {
1232 let mut data_index = 0;
1233 for blk in dudt.elmts() {
1234 if let Some(name) = blk.name() {
1235 let ptr = self.variables.get("dudt").unwrap();
1236 let i = self
1237 .context
1238 .i32_type()
1239 .const_int(data_index.try_into().unwrap(), false);
1240 let alloca = Self::get_ptr_to_index(
1241 &self.create_entry_block_builder(),
1242 self.real_type,
1243 ptr,
1244 i,
1245 blk.name().unwrap(),
1246 );
1247 self.variables.insert(name.to_owned(), alloca);
1248 }
1249 data_index += blk.nnz();
1250 }
1251 }
    /// Register a pointer for `tensor` (and each of its named blocks) into
    /// the variable table, offset into either the "constants" or "data"
    /// array depending on `is_constant`.
    fn insert_tensor(&mut self, tensor: &Tensor, is_constant: bool) {
        let var_name = if is_constant { "constants" } else { "data" };
        let ptr = *self.variables.get(var_name).unwrap();
        // Offset of the whole tensor within the backing array; advanced
        // per-block below for named sub-blocks.
        let mut data_index = self.layout.get_data_index(tensor.name()).unwrap();
        let i = self
            .context
            .i32_type()
            .const_int(data_index.try_into().unwrap(), false);
        // Entry-block GEP so the pointer dominates all uses in the function.
        let alloca = Self::get_ptr_to_index(
            &self.create_entry_block_builder(),
            self.real_type,
            &ptr,
            i,
            tensor.name(),
        );
        self.variables.insert(tensor.name().to_owned(), alloca);

        // Also expose each named block at its own offset.
        for blk in tensor.elmts() {
            if let Some(name) = blk.name() {
                let i = self
                    .context
                    .i32_type()
                    .const_int(data_index.try_into().unwrap(), false);
                let alloca = Self::get_ptr_to_index(
                    &self.create_entry_block_builder(),
                    self.real_type,
                    &ptr,
                    i,
                    name,
                );
                self.variables.insert(name.to_owned(), alloca);
            }
            data_index += blk.nnz();
        }
    }
    /// Look up a named parameter pointer; panics if it was never inserted.
    fn get_param(&self, name: &str) -> &PointerValue<'ctx> {
        self.variables.get(name).unwrap()
    }
1292
    /// Look up the pointer previously registered for `tensor` (see
    /// `insert_tensor`); panics if it was never inserted.
    fn get_var(&self, tensor: &Tensor) -> &PointerValue<'ctx> {
        self.variables.get(tensor.name()).unwrap()
    }
1296
    /// Return the math function called `name`, building it on first use.
    ///
    /// Simple unary/binary math maps to LLVM intrinsics; `sigmoid`,
    /// `arcsinh`/`arccosh`, `heaviside` and the hyperbolic functions are
    /// synthesised from `exp`/`log`/`sqrt` as small module-local functions.
    /// Successfully built functions are cached in `self.functions`
    /// (except intrinsics, which return early and are re-resolved each call).
    /// Returns `None` for unknown names or if emission fails mid-build.
    fn get_function(&mut self, name: &str) -> Option<FunctionValue<'ctx>> {
        match self.functions.get(name) {
            // Cache hit: hand back the previously built function.
            Some(&func) => Some(func),
            None => {
                let function = match name {
                    // These map directly onto LLVM intrinsics overloaded on
                    // the real type (e.g. llvm.sin.f64).
                    "sin" | "cos" | "tan" | "exp" | "log" | "log10" | "sqrt" | "abs"
                    | "copysign" | "pow" | "min" | "max" => {
                        // One overloaded type parameter regardless of arity.
                        let arg_len = 1;
                        let intrinsic_name = match name {
                            // LLVM spells these differently from the DSL.
                            "min" => "minnum",
                            "max" => "maxnum",
                            "abs" => "fabs",
                            _ => name,
                        };
                        let llvm_name = format!("llvm.{}.{}", intrinsic_name, self.real_type_str);
                        let intrinsic = Intrinsic::find(&llvm_name).unwrap();
                        let ret_type = self.real_type;

                        let args_types = std::iter::repeat_n(ret_type, arg_len)
                            .map(|f| f.into())
                            .collect::<Vec<BasicTypeEnum>>();
                        // NOTE: early return — intrinsic declarations are not
                        // inserted into the `self.functions` cache below.
                        return intrinsic.get_declaration(&self.module, args_types.as_slice());
                    }
                    // sigmoid(x) = 1 / (1 + exp(-x)), built from the exp
                    // intrinsic as a real function in this module.
                    "sigmoid" => {
                        let arg_len = 1;
                        let ret_type = self.real_type;

                        let args_types = std::iter::repeat_n(ret_type, arg_len)
                            .map(|f| f.into())
                            .collect::<Vec<BasicMetadataTypeEnum>>();

                        let fn_type = ret_type.fn_type(args_types.as_slice(), false);
                        let fn_val = self.module.add_function(name, fn_type, None);

                        for arg in fn_val.get_param_iter() {
                            arg.into_float_value().set_name("x");
                        }

                        // Remember the caller's insertion point; restored
                        // after the helper body is emitted.
                        let current_block = self.builder.get_insert_block().unwrap();
                        let basic_block = self.context.append_basic_block(fn_val, "entry");
                        self.builder.position_at_end(basic_block);
                        let x = fn_val.get_nth_param(0)?.into_float_value();
                        let one = self.real_type.const_float(1.0);
                        let negx = self.builder.build_float_neg(x, name).ok()?;
                        let exp = self.get_function("exp").unwrap();
                        let exp_negx = self
                            .builder
                            .build_call(exp, &[BasicMetadataValueEnum::FloatValue(negx)], name)
                            .ok()?;
                        let one_plus_exp_negx = self
                            .builder
                            .build_float_add(
                                exp_negx
                                    .try_as_basic_value()
                                    .left()
                                    .unwrap()
                                    .into_float_value(),
                                one,
                                name,
                            )
                            .ok()?;
                        let sigmoid = self
                            .builder
                            .build_float_div(one, one_plus_exp_negx, name)
                            .ok()?;
                        self.builder.build_return(Some(&sigmoid)).ok();
                        self.builder.position_at_end(current_block);
                        Some(fn_val)
                    }
                    // arcsinh(x) = ln(x + sqrt(x^2 + 1)),
                    // arccosh(x) = ln(x + sqrt(x^2 - 1)) — the sign of the
                    // constant added to x^2 is the only difference.
                    "arcsinh" | "arccosh" => {
                        let arg_len = 1;
                        let ret_type = self.real_type;

                        let args_types = std::iter::repeat_n(ret_type, arg_len)
                            .map(|f| f.into())
                            .collect::<Vec<BasicMetadataTypeEnum>>();

                        let fn_type = ret_type.fn_type(args_types.as_slice(), false);
                        let fn_val = self.module.add_function(name, fn_type, None);

                        for arg in fn_val.get_param_iter() {
                            arg.into_float_value().set_name("x");
                        }

                        let current_block = self.builder.get_insert_block().unwrap();
                        let basic_block = self.context.append_basic_block(fn_val, "entry");
                        self.builder.position_at_end(basic_block);
                        let x = fn_val.get_nth_param(0)?.into_float_value();
                        let one = match name {
                            "arccosh" => self.real_type.const_float(-1.0),
                            "arcsinh" => self.real_type.const_float(1.0),
                            _ => panic!("unknown function"),
                        };
                        let x_squared = self.builder.build_float_mul(x, x, name).ok()?;
                        let one_plus_x_squared =
                            self.builder.build_float_add(x_squared, one, name).ok()?;
                        let sqrt = self.get_function("sqrt").unwrap();
                        let sqrt_one_plus_x_squared = self
                            .builder
                            .build_call(
                                sqrt,
                                &[BasicMetadataValueEnum::FloatValue(one_plus_x_squared)],
                                name,
                            )
                            .unwrap()
                            .try_as_basic_value()
                            .left()
                            .unwrap()
                            .into_float_value();
                        let x_plus_sqrt_one_plus_x_squared = self
                            .builder
                            .build_float_add(x, sqrt_one_plus_x_squared, name)
                            .ok()?;
                        let ln = self.get_function("log").unwrap();
                        let result = self
                            .builder
                            .build_call(
                                ln,
                                &[BasicMetadataValueEnum::FloatValue(
                                    x_plus_sqrt_one_plus_x_squared,
                                )],
                                name,
                            )
                            .unwrap()
                            .try_as_basic_value()
                            .left()
                            .unwrap()
                            .into_float_value();
                        self.builder.build_return(Some(&result)).ok();
                        self.builder.position_at_end(current_block);
                        Some(fn_val)
                    }
                    // heaviside(x) = x >= 0 ? 1 : 0 (note: returns 1 at x=0).
                    "heaviside" => {
                        let arg_len = 1;
                        let ret_type = self.real_type;

                        let args_types = std::iter::repeat_n(ret_type, arg_len)
                            .map(|f| f.into())
                            .collect::<Vec<BasicMetadataTypeEnum>>();

                        let fn_type = ret_type.fn_type(args_types.as_slice(), false);
                        let fn_val = self.module.add_function(name, fn_type, None);

                        for arg in fn_val.get_param_iter() {
                            arg.into_float_value().set_name("x");
                        }

                        let current_block = self.builder.get_insert_block().unwrap();
                        let basic_block = self.context.append_basic_block(fn_val, "entry");
                        self.builder.position_at_end(basic_block);
                        let x = fn_val.get_nth_param(0)?.into_float_value();
                        let zero = self.real_type.const_float(0.0);
                        let one = self.real_type.const_float(1.0);
                        let result = self
                            .builder
                            .build_select(
                                self.builder
                                    .build_float_compare(FloatPredicate::OGE, x, zero, "x >= 0")
                                    .unwrap(),
                                one,
                                zero,
                                name,
                            )
                            .ok()?;
                        self.builder.build_return(Some(&result)).ok();
                        self.builder.position_at_end(current_block);
                        Some(fn_val)
                    }
                    // Hyperbolics built from exp:
                    //   tanh = (e^x - e^-x)/(e^x + e^-x),
                    //   sinh = (e^x - e^-x)/2, cosh = (e^x + e^-x)/2.
                    "tanh" | "sinh" | "cosh" => {
                        let arg_len = 1;
                        let ret_type = self.real_type;

                        let args_types = std::iter::repeat_n(ret_type, arg_len)
                            .map(|f| f.into())
                            .collect::<Vec<BasicMetadataTypeEnum>>();

                        let fn_type = ret_type.fn_type(args_types.as_slice(), false);
                        let fn_val = self.module.add_function(name, fn_type, None);

                        for arg in fn_val.get_param_iter() {
                            arg.into_float_value().set_name("x");
                        }

                        let current_block = self.builder.get_insert_block().unwrap();
                        let basic_block = self.context.append_basic_block(fn_val, "entry");
                        self.builder.position_at_end(basic_block);
                        let x = fn_val.get_nth_param(0)?.into_float_value();
                        let negx = self.builder.build_float_neg(x, name).ok()?;
                        let exp = self.get_function("exp").unwrap();
                        let exp_negx = self
                            .builder
                            .build_call(exp, &[BasicMetadataValueEnum::FloatValue(negx)], name)
                            .ok()?;
                        let expx = self
                            .builder
                            .build_call(exp, &[BasicMetadataValueEnum::FloatValue(x)], name)
                            .ok()?;
                        let expx_minus_exp_negx = self
                            .builder
                            .build_float_sub(
                                expx.try_as_basic_value().left().unwrap().into_float_value(),
                                exp_negx
                                    .try_as_basic_value()
                                    .left()
                                    .unwrap()
                                    .into_float_value(),
                                name,
                            )
                            .ok()?;
                        let expx_plus_exp_negx = self
                            .builder
                            .build_float_add(
                                expx.try_as_basic_value().left().unwrap().into_float_value(),
                                exp_negx
                                    .try_as_basic_value()
                                    .left()
                                    .unwrap()
                                    .into_float_value(),
                                name,
                            )
                            .ok()?;
                        let result = match name {
                            "tanh" => self
                                .builder
                                .build_float_div(expx_minus_exp_negx, expx_plus_exp_negx, name)
                                .ok()?,
                            "sinh" => self
                                .builder
                                .build_float_div(
                                    expx_minus_exp_negx,
                                    self.real_type.const_float(2.0),
                                    name,
                                )
                                .ok()?,
                            "cosh" => self
                                .builder
                                .build_float_div(
                                    expx_plus_exp_negx,
                                    self.real_type.const_float(2.0),
                                    name,
                                )
                                .ok()?,
                            _ => panic!("unknown function"),
                        };
                        self.builder.build_return(Some(&result)).ok();
                        self.builder.position_at_end(current_block);
                        Some(fn_val)
                    }
                    _ => None,
                }?;
                // Cache the synthesised function for subsequent lookups.
                self.functions.insert(name.to_owned(), function);
                Some(function)
            }
        }
    }
    /// The function currently being compiled; panics if none is set.
    #[inline]
    fn fn_value(&self) -> FunctionValue<'ctx> {
        self.fn_value_opt.unwrap()
    }
1560
    /// Pointer to the tensor currently being filled (set by
    /// `jit_compile_tensor`); panics if none is set.
    #[inline]
    fn tensor_ptr(&self) -> PointerValue<'ctx> {
        self.tensor_ptr_opt.unwrap()
    }
1565
1566 fn create_entry_block_builder(&self) -> Builder<'ctx> {
1568 let builder = self.context.create_builder();
1569 let entry = self.fn_value().get_first_basic_block().unwrap();
1570 match entry.get_first_instruction() {
1571 Some(first_instr) => builder.position_before(&first_instr),
1572 None => builder.position_at_end(entry),
1573 }
1574 builder
1575 }
1576
    /// Compile a rank-0 (scalar) tensor: evaluate its single expression and
    /// store the result into `res_ptr_opt` (or a fresh entry-block alloca).
    ///
    /// In threaded mode only thread 0 performs the store; all other threads
    /// branch straight to the exit block.
    fn jit_compile_scalar(
        &mut self,
        a: &Tensor,
        res_ptr_opt: Option<PointerValue<'ctx>>,
    ) -> Result<PointerValue<'ctx>> {
        let res_type = self.real_type;
        let res_ptr = match res_ptr_opt {
            Some(ptr) => ptr,
            None => self
                .create_entry_block_builder()
                .build_alloca(res_type, a.name())?,
        };
        let name = a.name();
        // A scalar tensor has exactly one element block.
        let elmt = a.elmts().first().unwrap();

        // Remember the caller's block; in threaded mode the thread-id test
        // is emitted there afterwards.
        let curr_block = self.builder.get_insert_block().unwrap();
        let mut next_block_opt = None;
        if self.threaded {
            let next_block = self.context.append_basic_block(self.fn_value(), "next");
            self.builder.position_at_end(next_block);
            next_block_opt = Some(next_block);
        }

        let float_value = self.jit_compile_expr(name, elmt.expr(), &[], elmt, None)?;
        self.builder.build_store(res_ptr, float_value)?;

        if self.threaded {
            let exit_block = self.context.append_basic_block(self.fn_value(), "exit");
            self.builder.build_unconditional_branch(exit_block)?;
            // Back-fill the original block with the thread-0 guard.
            self.builder.position_at_end(curr_block);

            let thread_id = self.get_param("thread_id");
            let thread_id = self
                .builder
                .build_load(self.int_type, *thread_id, "thread_id")
                .unwrap()
                .into_int_value();
            let is_first_thread = self.builder.build_int_compare(
                IntPredicate::EQ,
                thread_id,
                self.int_type.const_zero(),
                "is_first_thread",
            )?;
            // Only thread 0 computes and stores; everyone else skips.
            self.builder.build_conditional_branch(
                is_first_thread,
                next_block_opt.unwrap(),
                exit_block,
            )?;

            self.builder.position_at_end(exit_block);
        }

        Ok(res_ptr)
    }
1633
    /// Compile all element blocks of tensor `a` into `res_ptr_opt` (or a
    /// fresh entry-block alloca). Rank-0 tensors are delegated to
    /// `jit_compile_scalar`. Sets `tensor_ptr_opt` so expression compilation
    /// can reference the destination.
    fn jit_compile_tensor(
        &mut self,
        a: &Tensor,
        res_ptr_opt: Option<PointerValue<'ctx>>,
    ) -> Result<PointerValue<'ctx>> {
        if a.rank() == 0 {
            return self.jit_compile_scalar(a, res_ptr_opt);
        }

        let res_type = self.real_type;
        let res_ptr = match res_ptr_opt {
            Some(ptr) => ptr,
            None => self
                .create_entry_block_builder()
                .build_alloca(res_type, a.name())?,
        };

        self.tensor_ptr_opt = Some(res_ptr);

        for (i, blk) in a.elmts().iter().enumerate() {
            // Unnamed blocks get a synthetic "<tensor>-<index>" label.
            let default = format!("{}-{}", a.name(), i);
            let name = blk.name().unwrap_or(default.as_str());
            self.jit_compile_block(name, a, blk)?;
        }
        Ok(res_ptr)
    }
1662
    /// Dispatch compilation of one tensor element block based on the layout
    /// of its expression: dense, diagonal, or sparse (with a special path
    /// for sparse contractions).
    fn jit_compile_block(&mut self, name: &str, tensor: &Tensor, elmt: &TensorBlock) -> Result<()> {
        // Describes how expression-layout indices map onto the destination
        // tensor's layout.
        let translation = Translation::new(
            elmt.expr_layout(),
            elmt.layout(),
            elmt.start(),
            tensor.layout_ptr(),
        );

        if elmt.expr_layout().is_dense() {
            self.jit_compile_dense_block(name, elmt, &translation)
        } else if elmt.expr_layout().is_diagonal() {
            self.jit_compile_diagonal_block(name, elmt, &translation)
        } else if elmt.expr_layout().is_sparse() {
            match translation.source {
                TranslationFrom::SparseContraction { .. } => {
                    self.jit_compile_sparse_contraction_block(name, elmt, &translation)
                }
                _ => self.jit_compile_sparse_block(name, elmt, &translation),
            }
        } else {
            Err(anyhow!(
                "unsupported block layout: {:?}",
                elmt.expr_layout()
            ))
        }
    }
1689
    /// Compile a dense element block as a nest of `expr_rank` counted loops
    /// (one phi-driven loop per dimension). Handles dense contractions by
    /// accumulating into a stack slot over the innermost `contract_by`
    /// dimensions, and splits the outermost loop across threads when
    /// `self.threaded` is set.
    fn jit_compile_dense_block(
        &mut self,
        name: &str,
        elmt: &TensorBlock,
        translation: &Translation,
    ) -> Result<()> {
        let int_type = self.int_type;

        // `preblock` tracks the predecessor block feeding each loop's phi.
        let mut preblock = self.builder.get_insert_block().unwrap();
        let expr_rank = elmt.expr_layout().rank();
        let expr_shape = elmt
            .expr_layout()
            .shape()
            .mapv(|n| int_type.const_int(n.try_into().unwrap(), false));
        let one = int_type.const_int(1, false);

        // Row-major strides for the expression layout.
        let mut expr_strides = vec![1; expr_rank];
        if expr_rank > 0 {
            for i in (0..expr_rank - 1).rev() {
                expr_strides[i] = expr_strides[i + 1] * elmt.expr_layout().shape()[i + 1];
            }
        }
        let expr_strides = expr_strides
            .iter()
            .map(|&s| int_type.const_int(s.try_into().unwrap(), false))
            .collect::<Vec<IntValue>>();

        // One phi (loop index) and one basic block per dimension.
        let mut indices = Vec::new();
        let mut blocks = Vec::new();

        // For a dense contraction, accumulate the innermost `contract_by`
        // dims into a stack slot; strides address the contracted result.
        let (contract_sum, contract_by, contract_strides) =
            if let TranslationFrom::DenseContraction {
                contract_by,
                contract_len: _,
            } = translation.source
            {
                let contract_rank = expr_rank - contract_by;
                // NOTE(review): unlike the expr_strides loop above, this one
                // is not guarded for rank 0 — `contract_rank == 0` (i.e.
                // contracting every dimension) would underflow here.
                // Presumably a full contraction never reaches this path —
                // TODO confirm.
                let mut contract_strides = vec![1; contract_rank];
                for i in (0..contract_rank - 1).rev() {
                    contract_strides[i] =
                        contract_strides[i + 1] * elmt.expr_layout().shape()[i + 1];
                }
                let contract_strides = contract_strides
                    .iter()
                    .map(|&s| int_type.const_int(s.try_into().unwrap(), false))
                    .collect::<Vec<IntValue>>();
                (
                    Some(self.builder.build_alloca(self.real_type, "contract_sum")?),
                    contract_by,
                    Some(contract_strides),
                )
            } else {
                (None, 0, None)
            };

        // Threaded mode splits the outermost dimension across threads.
        let (thread_start, thread_end, test_done, next) = if self.threaded {
            let (start, end, test_done, next) =
                self.jit_threading_limits(*expr_shape.get(0).unwrap_or(&one))?;
            preblock = next;
            (Some(start), Some(end), Some(test_done), Some(next))
        } else {
            (None, None, None, None)
        };

        // Open one loop per dimension (headers only; latches are emitted in
        // the reverse loop below).
        for i in 0..expr_rank {
            let block = self.context.append_basic_block(self.fn_value(), name);
            self.builder.build_unconditional_branch(block)?;
            self.builder.position_at_end(block);

            let start_index = if i == 0 && self.threaded {
                thread_start.unwrap()
            } else {
                self.int_type.const_zero()
            };

            let curr_index = self.builder.build_phi(int_type, format!["i{i}"].as_str())?;
            curr_index.add_incoming(&[(&start_index, preblock)]);

            // Reset the accumulator when entering the contracted dims.
            if i == expr_rank - contract_by - 1 && contract_sum.is_some() {
                self.builder
                    .build_store(contract_sum.unwrap(), self.real_type.const_zero())?;
            }

            indices.push(curr_index);
            blocks.push(block);
            preblock = block;
        }

        let indices_int: Vec<IntValue> = indices
            .iter()
            .map(|i| i.as_basic_value().into_int_value())
            .collect();

        let float_value =
            self.jit_compile_expr(name, elmt.expr(), indices_int.as_slice(), elmt, None)?;

        if contract_sum.is_some() {
            // Contraction: accumulate into the stack slot; the store to the
            // destination happens when the contracted loops close.
            let contract_sum_value = self
                .build_load(self.real_type, contract_sum.unwrap(), "contract_sum")?
                .into_float_value();
            let new_contract_sum_value = self.builder.build_float_add(
                contract_sum_value,
                float_value,
                "new_contract_sum",
            )?;
            self.builder
                .build_store(contract_sum.unwrap(), new_contract_sum_value)?;
        } else {
            // No contraction: linearise the index and store directly.
            let expr_index = indices_int.iter().zip(expr_strides.iter()).fold(
                self.int_type.const_zero(),
                |acc, (i, s)| {
                    let tmp = self.builder.build_int_mul(*i, *s, "expr_index").unwrap();
                    self.builder.build_int_add(acc, tmp, "acc").unwrap()
                },
            );
            preblock = self.jit_compile_broadcast_and_store(
                name,
                elmt,
                float_value,
                expr_index,
                translation,
                preblock,
            )?;
        }

        // Close the loops innermost-first: increment, (maybe) flush the
        // contraction result, then emit the back-edge test.
        for i in (0..expr_rank).rev() {
            let next_index = self.builder.build_int_add(indices_int[i], one, name)?;
            indices[i].add_incoming(&[(&next_index, preblock)]);

            if i == expr_rank - contract_by - 1 && contract_sum.is_some() {
                let contract_sum_value = self
                    .build_load(self.real_type, contract_sum.unwrap(), "contract_sum")?
                    .into_float_value();
                let contract_strides = contract_strides.as_ref().unwrap();
                // Linearised index into the contracted (result) layout.
                let elmt_index = indices_int
                    .iter()
                    .take(contract_strides.len())
                    .zip(contract_strides.iter())
                    .fold(self.int_type.const_zero(), |acc, (i, s)| {
                        let tmp = self.builder.build_int_mul(*i, *s, "elmt_index").unwrap();
                        self.builder.build_int_add(acc, tmp, "acc").unwrap()
                    });
                self.jit_compile_store(name, elmt, elmt_index, contract_sum_value, translation)?;
            }

            // The outermost loop ends at this thread's chunk bound.
            let end_index = if i == 0 && self.threaded {
                thread_end.unwrap()
            } else {
                expr_shape[i]
            };

            let loop_while =
                self.builder
                    .build_int_compare(IntPredicate::ULT, next_index, end_index, name)?;
            let block = self.context.append_basic_block(self.fn_value(), name);
            self.builder
                .build_conditional_branch(loop_while, blocks[i], block)?;
            self.builder.position_at_end(block);
            preblock = block;
        }

        if self.threaded {
            self.jit_end_threading(
                thread_start.unwrap(),
                thread_end.unwrap(),
                test_done.unwrap(),
                next.unwrap(),
            )?;
        }
        Ok(())
    }
1868
    /// Compile a sparse contraction: an outer loop over the destination's
    /// non-zeros, each reading a `[start, end)` range from the translation
    /// table in the global indices array and summing the expression over
    /// that range of expression non-zeros. Threaded mode splits the outer
    /// loop across threads.
    fn jit_compile_sparse_contraction_block(
        &mut self,
        name: &str,
        elmt: &TensorBlock,
        translation: &Translation,
    ) -> Result<()> {
        // This path is only valid for sparse contractions.
        match translation.source {
            TranslationFrom::SparseContraction { .. } => {}
            _ => {
                panic!("expected sparse contraction")
            }
        }
        let int_type = self.int_type;

        // Offset of this expression's index data in the indices array.
        let layout_index = self.layout.get_layout_index(elmt.expr_layout()).unwrap();
        // Offset of the (start, end) pairs describing each output non-zero.
        let translation_index = self
            .layout
            .get_translation_index(elmt.expr_layout(), elmt.layout())
            .unwrap();
        let translation_index = translation_index + translation.get_from_index_in_data_layout();

        let final_contract_index =
            int_type.const_int(elmt.layout().nnz().try_into().unwrap(), false);
        let (thread_start, thread_end, test_done, next) = if self.threaded {
            let (start, end, test_done, next) = self.jit_threading_limits(final_contract_index)?;
            (Some(start), Some(end), Some(test_done), Some(next))
        } else {
            (None, None, None, None)
        };

        let preblock = self.builder.get_insert_block().unwrap();
        let contract_sum_ptr = self.builder.build_alloca(self.real_type, "contract_sum")?;

        // Outer loop: one iteration per destination non-zero.
        let block = self.context.append_basic_block(self.fn_value(), name);
        self.builder.build_unconditional_branch(block)?;
        self.builder.position_at_end(block);

        let contract_index = self.builder.build_phi(int_type, "i")?;
        let contract_start = if self.threaded {
            thread_start.unwrap()
        } else {
            int_type.const_zero()
        };
        contract_index.add_incoming(&[(&contract_start, preblock)]);

        // Each output non-zero has a (start, end) pair stored contiguously:
        // start at translation_index + 2*i, end immediately after.
        let start_index = self.builder.build_int_add(
            int_type.const_int(translation_index.try_into().unwrap(), false),
            self.builder.build_int_mul(
                int_type.const_int(2, false),
                contract_index.as_basic_value().into_int_value(),
                name,
            )?,
            name,
        )?;
        let end_index =
            self.builder
                .build_int_add(start_index, int_type.const_int(1, false), name)?;
        let start_ptr = self.build_gep(
            self.int_type,
            *self.get_param("indices"),
            &[start_index],
            "start_index_ptr",
        )?;
        let start_contract = self
            .build_load(self.int_type, start_ptr, "start")?
            .into_int_value();
        let end_ptr = self.build_gep(
            self.int_type,
            *self.get_param("indices"),
            &[end_index],
            "end_index_ptr",
        )?;
        let end_contract = self
            .build_load(self.int_type, end_ptr, "end")?
            .into_int_value();

        // Reset the accumulator for this output non-zero.
        self.builder
            .build_store(contract_sum_ptr, self.real_type.const_float(0.0))?;

        // Inner loop: sum the expression over [start_contract, end_contract).
        let contract_block = self
            .context
            .append_basic_block(self.fn_value(), format!("{name}_contract").as_str());
        self.builder.build_unconditional_branch(contract_block)?;
        self.builder.position_at_end(contract_block);

        let expr_index_phi = self.builder.build_phi(int_type, "j")?;
        expr_index_phi.add_incoming(&[(&start_contract, block)]);

        // Recover the multi-dimensional indices of this expression non-zero
        // from the flattened index data.
        let expr_index = expr_index_phi.as_basic_value().into_int_value();
        let elmt_index_mult_rank = self.builder.build_int_mul(
            expr_index,
            int_type.const_int(elmt.expr_layout().rank().try_into().unwrap(), false),
            name,
        )?;
        let indices_int = (0..elmt.expr_layout().rank())
            .map(|i| {
                let layout_index_plus_offset =
                    int_type.const_int((layout_index + i).try_into().unwrap(), false);
                let curr_index = self.builder.build_int_add(
                    elmt_index_mult_rank,
                    layout_index_plus_offset,
                    name,
                )?;
                let ptr = Self::get_ptr_to_index(
                    &self.builder,
                    self.int_type,
                    self.get_param("indices"),
                    curr_index,
                    name,
                );
                let index = self.build_load(self.int_type, ptr, name)?.into_int_value();
                Ok(index)
            })
            .collect::<Result<Vec<_>, anyhow::Error>>()?;

        let float_value = self.jit_compile_expr(
            name,
            elmt.expr(),
            indices_int.as_slice(),
            elmt,
            Some(expr_index),
        )?;
        let contract_sum_value = self
            .build_load(self.real_type, contract_sum_ptr, "contract_sum")?
            .into_float_value();
        let new_contract_sum_value =
            self.builder
                .build_float_add(contract_sum_value, float_value, "new_contract_sum")?;
        self.builder
            .build_store(contract_sum_ptr, new_contract_sum_value)?;

        // Inner-loop latch.
        let next_elmt_index =
            self.builder
                .build_int_add(expr_index, int_type.const_int(1, false), name)?;
        expr_index_phi.add_incoming(&[(&next_elmt_index, contract_block)]);

        let loop_while = self.builder.build_int_compare(
            IntPredicate::ULT,
            next_elmt_index,
            end_contract,
            name,
        )?;
        let post_contract_block = self.context.append_basic_block(self.fn_value(), name);
        self.builder
            .build_conditional_branch(loop_while, contract_block, post_contract_block)?;
        self.builder.position_at_end(post_contract_block);

        // Store the accumulated value for this output non-zero.
        self.jit_compile_store(
            name,
            elmt,
            contract_index.as_basic_value().into_int_value(),
            new_contract_sum_value,
            translation,
        )?;

        // Outer-loop latch.
        let next_contract_index = self.builder.build_int_add(
            contract_index.as_basic_value().into_int_value(),
            int_type.const_int(1, false),
            name,
        )?;
        contract_index.add_incoming(&[(&next_contract_index, post_contract_block)]);

        let loop_while = self.builder.build_int_compare(
            IntPredicate::ULT,
            next_contract_index,
            thread_end.unwrap_or(final_contract_index),
            name,
        )?;
        let post_block = self.context.append_basic_block(self.fn_value(), name);
        self.builder
            .build_conditional_branch(loop_while, block, post_block)?;
        self.builder.position_at_end(post_block);

        if self.threaded {
            self.jit_end_threading(
                thread_start.unwrap(),
                thread_end.unwrap(),
                test_done.unwrap(),
                next.unwrap(),
            )?;
        }

        Ok(())
    }
2063
    /// Compile a (non-contracting) sparse element block: a single loop over
    /// the expression's non-zeros, recovering each entry's multi-dimensional
    /// indices from the global indices array, evaluating the expression and
    /// storing via the broadcast/store helper. Threaded mode splits the
    /// non-zero range across threads.
    fn jit_compile_sparse_block(
        &mut self,
        name: &str,
        elmt: &TensorBlock,
        translation: &Translation,
    ) -> Result<()> {
        let int_type = self.int_type;

        // Offset of this expression's index data in the indices array.
        let layout_index = self.layout.get_layout_index(elmt.expr_layout()).unwrap();
        let start_index = int_type.const_int(0, false);
        let end_index = int_type.const_int(elmt.expr_layout().nnz().try_into().unwrap(), false);

        let (thread_start, thread_end, test_done, next) = if self.threaded {
            let (start, end, test_done, next) = self.jit_threading_limits(end_index)?;
            (Some(start), Some(end), Some(test_done), Some(next))
        } else {
            (None, None, None, None)
        };

        let preblock = self.builder.get_insert_block().unwrap();
        // `block` is mutable: broadcast/store may finish in a later block.
        let mut block = self.context.append_basic_block(self.fn_value(), name);
        self.builder.build_unconditional_branch(block)?;
        self.builder.position_at_end(block);

        let curr_index = self.builder.build_phi(int_type, "i")?;
        curr_index.add_incoming(&[(&thread_start.unwrap_or(start_index), preblock)]);

        // Recover the multi-dimensional indices of non-zero `elmt_index`.
        let elmt_index = curr_index.as_basic_value().into_int_value();
        let elmt_index_mult_rank = self.builder.build_int_mul(
            elmt_index,
            int_type.const_int(elmt.expr_layout().rank().try_into().unwrap(), false),
            name,
        )?;
        let indices_int = (0..elmt.expr_layout().rank())
            .map(|i| {
                let layout_index_plus_offset =
                    int_type.const_int((layout_index + i).try_into().unwrap(), false);
                let curr_index = self.builder.build_int_add(
                    elmt_index_mult_rank,
                    layout_index_plus_offset,
                    name,
                )?;
                let ptr = Self::get_ptr_to_index(
                    &self.builder,
                    self.int_type,
                    self.get_param("indices"),
                    curr_index,
                    name,
                );
                Ok(self.build_load(self.int_type, ptr, name)?.into_int_value())
            })
            .collect::<Result<Vec<_>, anyhow::Error>>()?;

        let float_value = self.jit_compile_expr(
            name,
            elmt.expr(),
            indices_int.as_slice(),
            elmt,
            Some(elmt_index),
        )?;

        block = self.jit_compile_broadcast_and_store(
            name,
            elmt,
            float_value,
            elmt_index,
            translation,
            block,
        )?;

        // Loop latch: increment and test against this thread's bound.
        let one = int_type.const_int(1, false);
        let next_index = self.builder.build_int_add(elmt_index, one, name)?;
        curr_index.add_incoming(&[(&next_index, block)]);

        let loop_while = self.builder.build_int_compare(
            IntPredicate::ULT,
            next_index,
            thread_end.unwrap_or(end_index),
            name,
        )?;
        let post_block = self.context.append_basic_block(self.fn_value(), name);
        self.builder
            .build_conditional_branch(loop_while, block, post_block)?;
        self.builder.position_at_end(post_block);

        if self.threaded {
            self.jit_end_threading(
                thread_start.unwrap(),
                thread_end.unwrap(),
                test_done.unwrap(),
                next.unwrap(),
            )?;
        }

        Ok(())
    }
2167
    /// JIT-compiles a loop for a diagonal tensor block: the expression is
    /// evaluated once per non-zero with every axis index equal to the loop
    /// counter (i.e. at positions (i, i, ..., i)).
    fn jit_compile_diagonal_block(
        &mut self,
        name: &str,
        elmt: &TensorBlock,
        translation: &Translation,
    ) -> Result<()> {
        let int_type = self.int_type;

        let start_index = int_type.const_int(0, false);
        let end_index = int_type.const_int(elmt.expr_layout().nnz().try_into().unwrap(), false);

        // When threaded, iterate only this thread's slice of [0, nnz); the
        // extra handles are needed to close the threading region below.
        let (thread_start, thread_end, test_done, next) = if self.threaded {
            let (start, end, test_done, next) = self.jit_threading_limits(end_index)?;
            (Some(start), Some(end), Some(test_done), Some(next))
        } else {
            (None, None, None, None)
        };

        let preblock = self.builder.get_insert_block().unwrap();
        let mut block = self.context.append_basic_block(self.fn_value(), name);
        self.builder.build_unconditional_branch(block)?;
        self.builder.position_at_end(block);

        // Loop counter phi; the back edge is added after the body is emitted.
        let curr_index = self.builder.build_phi(int_type, "i")?;
        curr_index.add_incoming(&[(&thread_start.unwrap_or(start_index), preblock)]);

        let elmt_index = curr_index.as_basic_value().into_int_value();
        // On the diagonal every axis shares the same index value.
        let indices_int: Vec<IntValue> =
            (0..elmt.expr_layout().rank()).map(|_| elmt_index).collect();

        let float_value = self.jit_compile_expr(
            name,
            elmt.expr(),
            indices_int.as_slice(),
            elmt,
            Some(elmt_index),
        )?;

        // Storing may emit an inner broadcast loop; continue from the block
        // it returns so the loop latch lands in the right place.
        block = self.jit_compile_broadcast_and_store(
            name,
            elmt,
            float_value,
            elmt_index,
            translation,
            block,
        )?;

        // Loop latch: increment, complete the phi, branch back while in range.
        let one = int_type.const_int(1, false);
        let next_index = self.builder.build_int_add(elmt_index, one, name)?;
        curr_index.add_incoming(&[(&next_index, block)]);

        let loop_while = self.builder.build_int_compare(
            IntPredicate::ULT,
            next_index,
            thread_end.unwrap_or(end_index),
            name,
        )?;
        let post_block = self.context.append_basic_block(self.fn_value(), name);
        self.builder
            .build_conditional_branch(loop_while, block, post_block)?;
        self.builder.position_at_end(post_block);

        if self.threaded {
            self.jit_end_threading(
                thread_start.unwrap(),
                thread_end.unwrap(),
                test_done.unwrap(),
                next.unwrap(),
            )?;
        }

        Ok(())
    }
2248
    /// Stores an evaluated expression value, expanding it over a broadcast
    /// dimension when the translation requires it.
    ///
    /// Returns the basic block control flow ends in, so the caller can keep
    /// emitting the enclosing loop from the right place:
    /// * `Broadcast` — emits an inner loop storing `float_value` at
    ///   `expr_index * broadcast_len + j` for `j` in `[0, broadcast_len)`,
    ///   and returns the block after that loop.
    /// * `ElementWise` / `DiagonalContraction` — a single store at
    ///   `expr_index`; returns `pre_block` unchanged.
    /// * any other source kind — error (invalid for this code path).
    fn jit_compile_broadcast_and_store(
        &mut self,
        name: &str,
        elmt: &TensorBlock,
        float_value: FloatValue<'ctx>,
        expr_index: IntValue<'ctx>,
        translation: &Translation,
        pre_block: BasicBlock<'ctx>,
    ) -> Result<BasicBlock<'ctx>> {
        let int_type = self.int_type;
        let one = int_type.const_int(1, false);
        let zero = int_type.const_int(0, false);
        match translation.source {
            TranslationFrom::Broadcast {
                broadcast_by: _,
                broadcast_len,
            } => {
                let bcast_start_index = zero;
                let bcast_end_index = int_type.const_int(broadcast_len.try_into().unwrap(), false);

                // Inner counting loop over the broadcast dimension, built
                // with the same phi pattern as the outer element loop.
                let bcast_block = self.context.append_basic_block(self.fn_value(), name);
                self.builder.build_unconditional_branch(bcast_block)?;
                self.builder.position_at_end(bcast_block);
                let bcast_index = self.builder.build_phi(int_type, "broadcast_index")?;
                bcast_index.add_incoming(&[(&bcast_start_index, pre_block)]);

                // Destination slot: each source element owns a contiguous run
                // of `broadcast_len` target slots.
                let store_index = self.builder.build_int_add(
                    self.builder
                        .build_int_mul(expr_index, bcast_end_index, "store_index")?,
                    bcast_index.as_basic_value().into_int_value(),
                    "bcast_store_index",
                )?;
                self.jit_compile_store(name, elmt, store_index, float_value, translation)?;

                // Loop latch: increment, close the phi, branch while in range.
                let bcast_next_index = self.builder.build_int_add(
                    bcast_index.as_basic_value().into_int_value(),
                    one,
                    name,
                )?;
                bcast_index.add_incoming(&[(&bcast_next_index, bcast_block)]);

                let bcast_cond = self.builder.build_int_compare(
                    IntPredicate::ULT,
                    bcast_next_index,
                    bcast_end_index,
                    "broadcast_cond",
                )?;
                let post_bcast_block = self.context.append_basic_block(self.fn_value(), name);
                self.builder
                    .build_conditional_branch(bcast_cond, bcast_block, post_bcast_block)?;
                self.builder.position_at_end(post_bcast_block);

                Ok(post_bcast_block)
            }
            TranslationFrom::ElementWise | TranslationFrom::DiagonalContraction { .. } => {
                self.jit_compile_store(name, elmt, expr_index, float_value, translation)?;
                Ok(pre_block)
            }
            _ => Err(anyhow!("Invalid translation")),
        }
    }
2315
    /// Emits the store of `float_value` into the current tensor at the slot
    /// selected by the translation target:
    /// * `Contiguous` — destination index is `start + store_index`.
    /// * `Sparse` — destination index is loaded from the shared "indices"
    ///   array via a precomputed translation table for this layout pair.
    fn jit_compile_store(
        &mut self,
        name: &str,
        elmt: &TensorBlock,
        store_index: IntValue<'ctx>,
        float_value: FloatValue<'ctx>,
        translation: &Translation,
    ) -> Result<()> {
        let int_type = self.int_type;
        let res_index = match &translation.target {
            TranslationTo::Contiguous { start, end: _ } => {
                let start_const = int_type.const_int((*start).try_into().unwrap(), false);
                self.builder.build_int_add(start_const, store_index, name)?
            }
            TranslationTo::Sparse { indices: _ } => {
                // Base of the expr-layout -> tensor-layout translation table
                // within the indices array.
                let translate_index = self
                    .layout
                    .get_translation_index(elmt.expr_layout(), elmt.layout())
                    .unwrap();
                let translate_store_index =
                    translate_index + translation.get_to_index_in_data_layout();
                let translate_store_index =
                    int_type.const_int(translate_store_index.try_into().unwrap(), false);
                let elmt_index_strided = store_index;
                let curr_index =
                    self.builder
                        .build_int_add(elmt_index_strided, translate_store_index, name)?;
                let ptr = Self::get_ptr_to_index(
                    &self.builder,
                    self.int_type,
                    self.get_param("indices"),
                    curr_index,
                    name,
                );
                self.build_load(self.int_type, ptr, name)?.into_int_value()
            }
        };

        let resi_ptr = Self::get_ptr_to_index(
            &self.builder,
            self.real_type,
            &self.tensor_ptr(),
            res_index,
            name,
        );
        self.builder.build_store(resi_ptr, float_value)?;
        Ok(())
    }
2365
    /// Recursively JIT-compiles an AST expression into a float value.
    ///
    /// * `index` — the current loop indices, one per axis of the enclosing
    ///   tensor block.
    /// * `expr_index` — the flat element index; used for operands with
    ///   sparse/diagonal layouts, where the axis indices do not address
    ///   storage directly.
    fn jit_compile_expr(
        &mut self,
        name: &str,
        expr: &Ast,
        index: &[IntValue<'ctx>],
        elmt: &TensorBlock,
        expr_index: Option<IntValue<'ctx>>,
    ) -> Result<FloatValue<'ctx>> {
        // Prefer the tensor block's own name for generated value names.
        let name = elmt.name().unwrap_or(name);
        match &expr.kind {
            AstKind::Binop(binop) => {
                let lhs =
                    self.jit_compile_expr(name, binop.left.as_ref(), index, elmt, expr_index)?;
                let rhs =
                    self.jit_compile_expr(name, binop.right.as_ref(), index, elmt, expr_index)?;
                match binop.op {
                    '*' => Ok(self.builder.build_float_mul(lhs, rhs, name)?),
                    '/' => Ok(self.builder.build_float_div(lhs, rhs, name)?),
                    '-' => Ok(self.builder.build_float_sub(lhs, rhs, name)?),
                    '+' => Ok(self.builder.build_float_add(lhs, rhs, name)?),
                    unknown => Err(anyhow!("unknown binop op '{}'", unknown)),
                }
            }
            AstKind::Monop(monop) => {
                let child =
                    self.jit_compile_expr(name, monop.child.as_ref(), index, elmt, expr_index)?;
                match monop.op {
                    '-' => Ok(self.builder.build_float_neg(child, name)?),
                    unknown => Err(anyhow!("unknown monop op '{}'", unknown)),
                }
            }
            // Function calls: compile each argument, then emit a direct call.
            AstKind::Call(call) => match self.get_function(call.fn_name) {
                Some(function) => {
                    let mut args: Vec<BasicMetadataValueEnum> = Vec::new();
                    for arg in call.args.iter() {
                        let arg_val =
                            self.jit_compile_expr(name, arg.as_ref(), index, elmt, expr_index)?;
                        args.push(BasicMetadataValueEnum::FloatValue(arg_val));
                    }
                    let ret_value = self
                        .builder
                        .build_call(function, args.as_slice(), name)?
                        .try_as_basic_value()
                        .left()
                        .unwrap()
                        .into_float_value();
                    Ok(ret_value)
                }
                None => Err(anyhow!("unknown function call '{}'", call.fn_name)),
            },
            AstKind::CallArg(arg) => {
                self.jit_compile_expr(name, &arg.expression, index, elmt, expr_index)
            }
            AstKind::Number(value) => Ok(self.real_type.const_float(*value)),
            // Named tensor operand: compute the flat storage index and load.
            AstKind::Name(iname) => {
                let ptr = self.get_param(iname.name);
                let layout = self.layout.get_layout(iname.name).unwrap();
                let iname_elmt_index = if layout.is_dense() {
                    // Dense: build a row-major flat index from the loop
                    // indices, permuted by matching the operand's index
                    // letters against the block's index letters.
                    let mut no_transform = true;
                    let mut iname_index = Vec::new();
                    for (i, c) in iname.indices.iter().enumerate() {
                        // Position of this operand axis within the block's
                        // indices. NOTE(review): the `unwrap_or(len)`
                        // fallback feeds `index[pi]`, which would be
                        // out-of-bounds unless `index` is always longer
                        // than `elmt.indices()` — confirm with callers.
                        let pi = elmt
                            .indices()
                            .iter()
                            .position(|x| x == c)
                            .unwrap_or(elmt.indices().len());
                        iname_index.push(index[pi]);
                        // NOTE(review): `no_transform` is computed but never
                        // read afterwards — confirm before removing.
                        no_transform = no_transform && pi == i;
                    }
                    if !iname_index.is_empty() {
                        // Horner-style flattening: innermost axis varies
                        // fastest, stride accumulates trailing shape dims.
                        let mut iname_elmt_index = *iname_index.last().unwrap();
                        let mut stride = 1u64;
                        for i in (0..iname_index.len() - 1).rev() {
                            let iname_i = iname_index[i];
                            let shapei: u64 = layout.shape()[i + 1].try_into().unwrap();
                            stride *= shapei;
                            // NOTE(review): strides use i32_type while most
                            // indexing uses self.int_type — confirm the two
                            // are the same width.
                            let stride_intval = self.context.i32_type().const_int(stride, false);
                            let stride_mul_i =
                                self.builder.build_int_mul(stride_intval, iname_i, name)?;
                            iname_elmt_index =
                                self.builder
                                    .build_int_add(iname_elmt_index, stride_mul_i, name)?;
                        }
                        Some(iname_elmt_index)
                    } else {
                        // Scalar operand: load element 0.
                        let zero = self.context.i32_type().const_int(0, false);
                        Some(zero)
                    }
                } else if layout.is_sparse() || layout.is_diagonal() {
                    // Sparse/diagonal operands share the enclosing block's
                    // flat element index (assumes matching sparsity).
                    expr_index
                } else {
                    panic!("unexpected layout");
                };
                let value_ptr = match iname_elmt_index {
                    Some(index) => {
                        Self::get_ptr_to_index(&self.builder, self.real_type, ptr, index, name)
                    }
                    None => *ptr,
                };
                Ok(self
                    .build_load(self.real_type, value_ptr, name)?
                    .into_float_value())
            }
            // Time-derivative operands (e.g. dudt) are scalar parameters.
            AstKind::NamedGradient(name) => {
                let name_str = name.to_string();
                let ptr = self.get_param(name_str.as_str());
                Ok(self
                    .build_load(self.real_type, *ptr, name_str.as_str())?
                    .into_float_value())
            }
            AstKind::Index(_) => todo!(),
            AstKind::Slice(_) => todo!(),
            AstKind::Integer(_) => todo!(),
            _ => panic!("unexprected astkind"),
        }
    }
2489
2490 fn clear(&mut self) {
2491 self.variables.clear();
2492 self.fn_value_opt = None;
2494 self.tensor_ptr_opt = None;
2495 }
2496
2497 fn function_arg_alloca(&mut self, name: &str, arg: BasicValueEnum<'ctx>) -> PointerValue<'ctx> {
2498 match arg {
2499 BasicValueEnum::PointerValue(v) => v,
2500 BasicValueEnum::FloatValue(v) => {
2501 let alloca = self
2502 .create_entry_block_builder()
2503 .build_alloca(arg.get_type(), name)
2504 .unwrap();
2505 self.builder.build_store(alloca, v).unwrap();
2506 alloca
2507 }
2508 BasicValueEnum::IntValue(v) => {
2509 let alloca = self
2510 .create_entry_block_builder()
2511 .build_alloca(arg.get_type(), name)
2512 .unwrap();
2513 self.builder.build_store(alloca, v).unwrap();
2514 alloca
2515 }
2516 _ => unreachable!(),
2517 }
2518 }
2519
    /// Compiles the `set_u0` function, which fills the initial state vector:
    /// `void set_u0(real* u0, real* data, int thread_id, int thread_dim)`.
    ///
    /// Evaluates every input-dependent tensor definition (into `data`), then
    /// the state tensor (into `u0`), emitting a synchronization barrier call
    /// after each tensor for threaded execution.
    pub fn compile_set_u0<'m>(&mut self, model: &'m DiscreteModel) -> Result<FunctionValue<'ctx>> {
        self.clear();
        let void_type = self.context.void_type();
        let fn_type = void_type.fn_type(
            &[
                self.real_ptr_type.into(),
                self.real_ptr_type.into(),
                self.int_type.into(),
                self.int_type.into(),
            ],
            false,
        );
        let fn_arg_names = &["u0", "data", "thread_id", "thread_dim"];
        let function = self.module.add_function("set_u0", fn_type, None);

        // Mark the two pointer parameters (`u0`, `data`) as noalias to help
        // LLVM optimize.
        let alias_id = Attribute::get_named_enum_kind_id("noalias");
        let noalign = self.context.create_enum_attribute(alias_id, 0);
        for i in &[0, 1] {
            function.add_attribute(AttributeLoc::Param(*i), noalign);
        }

        let basic_block = self.context.append_basic_block(function, "entry");
        self.fn_value_opt = Some(function);
        self.builder.position_at_end(basic_block);

        // Spill scalar args to allocas and register all params by name.
        for (i, arg) in function.get_param_iter().enumerate() {
            let name = fn_arg_names[i];
            let alloca = self.function_arg_alloca(name, arg);
            self.insert_param(name, alloca);
        }

        self.insert_data(model);
        self.insert_indices();

        // One barrier per input-dependent tensor, plus one for the state.
        let mut nbarriers = 0;
        let total_barriers = (model.input_dep_defns().len() + 1) as u64;
        #[allow(clippy::explicit_counter_loop)]
        for a in model.input_dep_defns() {
            self.jit_compile_tensor(a, Some(*self.get_var(a)))?;
            self.jit_compile_call_barrier(nbarriers, total_barriers);
            nbarriers += 1;
        }

        self.jit_compile_tensor(model.state(), Some(*self.get_param("u0")))?;
        self.jit_compile_call_barrier(nbarriers, total_barriers);

        self.builder.build_return(None)?;

        // Verify the generated IR; on failure dump it and delete the
        // function so the module stays valid.
        if function.verify(true) {
            Ok(function)
        } else {
            function.print_to_stderr();
            unsafe {
                function.delete();
            }
            Err(anyhow!("Invalid generated function."))
        }
    }
2579
2580 pub fn compile_calc_out<'m>(
2581 &mut self,
2582 model: &'m DiscreteModel,
2583 include_constants: bool,
2584 ) -> Result<FunctionValue<'ctx>> {
2585 self.clear();
2586 let void_type = self.context.void_type();
2587 let fn_type = void_type.fn_type(
2588 &[
2589 self.real_type.into(),
2590 self.real_ptr_type.into(),
2591 self.real_ptr_type.into(),
2592 self.real_ptr_type.into(),
2593 self.int_type.into(),
2594 self.int_type.into(),
2595 ],
2596 false,
2597 );
2598 let fn_arg_names = &["t", "u", "data", "out", "thread_id", "thread_dim"];
2599 let function_name = if include_constants {
2600 "calc_out_full"
2601 } else {
2602 "calc_out"
2603 };
2604 let function = self.module.add_function(function_name, fn_type, None);
2605
2606 let alias_id = Attribute::get_named_enum_kind_id("noalias");
2608 let noalign = self.context.create_enum_attribute(alias_id, 0);
2609 for i in &[1, 2] {
2610 function.add_attribute(AttributeLoc::Param(*i), noalign);
2611 }
2612
2613 let basic_block = self.context.append_basic_block(function, "entry");
2614 self.fn_value_opt = Some(function);
2615 self.builder.position_at_end(basic_block);
2616
2617 for (i, arg) in function.get_param_iter().enumerate() {
2618 let name = fn_arg_names[i];
2619 let alloca = self.function_arg_alloca(name, arg);
2620 self.insert_param(name, alloca);
2621 }
2622
2623 self.insert_state(model.state());
2624 self.insert_data(model);
2625 self.insert_indices();
2626
2627 if let Some(out) = model.out() {
2633 let mut nbarriers = 0;
2634 let mut total_barriers =
2635 (model.time_dep_defns().len() + model.state_dep_defns().len() + 1) as u64;
2636 if include_constants {
2637 total_barriers += model.input_dep_defns().len() as u64;
2638 for tensor in model.input_dep_defns() {
2640 self.jit_compile_tensor(tensor, Some(*self.get_var(tensor)))?;
2641 self.jit_compile_call_barrier(nbarriers, total_barriers);
2642 nbarriers += 1;
2643 }
2644 }
2645
2646 for tensor in model.time_dep_defns() {
2648 self.jit_compile_tensor(tensor, Some(*self.get_var(tensor)))?;
2649 self.jit_compile_call_barrier(nbarriers, total_barriers);
2650 nbarriers += 1;
2651 }
2652
2653 #[allow(clippy::explicit_counter_loop)]
2655 for a in model.state_dep_defns() {
2656 self.jit_compile_tensor(a, Some(*self.get_var(a)))?;
2657 self.jit_compile_call_barrier(nbarriers, total_barriers);
2658 nbarriers += 1;
2659 }
2660
2661 self.jit_compile_tensor(out, Some(*self.get_var(model.out().unwrap())))?;
2662 self.jit_compile_call_barrier(nbarriers, total_barriers);
2663 }
2664 self.builder.build_return(None)?;
2665
2666 if function.verify(true) {
2667 Ok(function)
2668 } else {
2669 function.print_to_stderr();
2670 unsafe {
2671 function.delete();
2672 }
2673 Err(anyhow!("Invalid generated function."))
2674 }
2675 }
2676
    /// Compiles the stop-condition (root-finding) function:
    /// `void calc_stop(real t, real* u, real* data, real* root, int thread_id, int thread_dim)`.
    ///
    /// If the model declares a stop tensor, evaluates the time-dependent and
    /// state-dependent tensor definitions and then the stop tensor into
    /// `root`, with a synchronization barrier after each; otherwise the
    /// generated function is a no-op.
    pub fn compile_calc_stop<'m>(
        &mut self,
        model: &'m DiscreteModel,
    ) -> Result<FunctionValue<'ctx>> {
        self.clear();
        let void_type = self.context.void_type();
        let fn_type = void_type.fn_type(
            &[
                self.real_type.into(),
                self.real_ptr_type.into(),
                self.real_ptr_type.into(),
                self.real_ptr_type.into(),
                self.int_type.into(),
                self.int_type.into(),
            ],
            false,
        );
        let fn_arg_names = &["t", "u", "data", "root", "thread_id", "thread_dim"];
        let function = self.module.add_function("calc_stop", fn_type, None);

        // Mark the pointer parameters (`u`, `data`, `root`) as noalias to
        // help LLVM optimize.
        let alias_id = Attribute::get_named_enum_kind_id("noalias");
        let noalign = self.context.create_enum_attribute(alias_id, 0);
        for i in &[1, 2, 3] {
            function.add_attribute(AttributeLoc::Param(*i), noalign);
        }

        let basic_block = self.context.append_basic_block(function, "entry");
        self.fn_value_opt = Some(function);
        self.builder.position_at_end(basic_block);

        // Spill scalar args to allocas and register all params by name.
        for (i, arg) in function.get_param_iter().enumerate() {
            let name = fn_arg_names[i];
            let alloca = self.function_arg_alloca(name, arg);
            self.insert_param(name, alloca);
        }

        self.insert_state(model.state());
        self.insert_data(model);
        self.insert_indices();

        if let Some(stop) = model.stop() {
            // One barrier per dependent tensor, plus one for the stop tensor.
            let mut nbarriers = 0;
            let total_barriers =
                (model.time_dep_defns().len() + model.state_dep_defns().len() + 1) as u64;
            for tensor in model.time_dep_defns() {
                self.jit_compile_tensor(tensor, Some(*self.get_var(tensor)))?;
                self.jit_compile_call_barrier(nbarriers, total_barriers);
                nbarriers += 1;
            }

            for a in model.state_dep_defns() {
                self.jit_compile_tensor(a, Some(*self.get_var(a)))?;
                self.jit_compile_call_barrier(nbarriers, total_barriers);
                nbarriers += 1;
            }

            let res_ptr = self.get_param("root");
            self.jit_compile_tensor(stop, Some(*res_ptr))?;
            self.jit_compile_call_barrier(nbarriers, total_barriers);
        }
        self.builder.build_return(None)?;

        // Verify the generated IR; on failure dump it and delete the
        // function so the module stays valid.
        if function.verify(true) {
            Ok(function)
        } else {
            function.print_to_stderr();
            unsafe {
                function.delete();
            }
            Err(anyhow!("Invalid generated function."))
        }
    }
2752
    /// Compiles the right-hand-side function:
    /// `void rhs(real t, real* u, real* data, real* rr, int thread_id, int thread_dim)`
    /// (named `rhs_full` when `include_constants` is true, in which case the
    /// input-dependent constant tensors are recomputed as well).
    ///
    /// Evaluates the time-dependent then state-dependent tensor definitions,
    /// then the rhs tensor into `rr`, with a barrier call after each.
    pub fn compile_rhs<'m>(
        &mut self,
        model: &'m DiscreteModel,
        include_constants: bool,
    ) -> Result<FunctionValue<'ctx>> {
        self.clear();
        let void_type = self.context.void_type();
        let fn_type = void_type.fn_type(
            &[
                self.real_type.into(),
                self.real_ptr_type.into(),
                self.real_ptr_type.into(),
                self.real_ptr_type.into(),
                self.int_type.into(),
                self.int_type.into(),
            ],
            false,
        );
        let fn_arg_names = &["t", "u", "data", "rr", "thread_id", "thread_dim"];
        let function_name = if include_constants { "rhs_full" } else { "rhs" };
        let function = self.module.add_function(function_name, fn_type, None);

        // Mark the pointer parameters (`u`, `data`, `rr`) as noalias to help
        // LLVM optimize.
        let alias_id = Attribute::get_named_enum_kind_id("noalias");
        let noalign = self.context.create_enum_attribute(alias_id, 0);
        for i in &[1, 2, 3] {
            function.add_attribute(AttributeLoc::Param(*i), noalign);
        }

        let basic_block = self.context.append_basic_block(function, "entry");
        self.fn_value_opt = Some(function);
        self.builder.position_at_end(basic_block);

        // Spill scalar args to allocas and register all params by name.
        for (i, arg) in function.get_param_iter().enumerate() {
            let name = fn_arg_names[i];
            let alloca = self.function_arg_alloca(name, arg);
            self.insert_param(name, alloca);
        }

        self.insert_state(model.state());
        self.insert_data(model);
        self.insert_indices();

        // One barrier per dependent tensor, plus one for the rhs tensor.
        let mut nbarriers = 0;
        let mut total_barriers =
            (model.time_dep_defns().len() + model.state_dep_defns().len() + 1) as u64;
        if include_constants {
            total_barriers += model.input_dep_defns().len() as u64;
            // Recompute the input-dependent (constant) tensors too.
            for tensor in model.input_dep_defns() {
                self.jit_compile_tensor(tensor, Some(*self.get_var(tensor)))?;
                self.jit_compile_call_barrier(nbarriers, total_barriers);
                nbarriers += 1;
            }
        }

        for tensor in model.time_dep_defns() {
            self.jit_compile_tensor(tensor, Some(*self.get_var(tensor)))?;
            self.jit_compile_call_barrier(nbarriers, total_barriers);
            nbarriers += 1;
        }

        for a in model.state_dep_defns() {
            self.jit_compile_tensor(a, Some(*self.get_var(a)))?;
            self.jit_compile_call_barrier(nbarriers, total_barriers);
            nbarriers += 1;
        }

        let res_ptr = self.get_param("rr");
        self.jit_compile_tensor(model.rhs(), Some(*res_ptr))?;
        self.jit_compile_call_barrier(nbarriers, total_barriers);

        self.builder.build_return(None)?;

        // Verify the generated IR; on failure dump it and delete the
        // function so the module stays valid.
        if function.verify(true) {
            Ok(function)
        } else {
            function.print_to_stderr();
            unsafe {
                function.delete();
            }
            Err(anyhow!("Invalid generated function."))
        }
    }
2840
2841 pub fn compile_mass<'m>(&mut self, model: &'m DiscreteModel) -> Result<FunctionValue<'ctx>> {
2842 self.clear();
2843 let void_type = self.context.void_type();
2844 let fn_type = void_type.fn_type(
2845 &[
2846 self.real_type.into(),
2847 self.real_ptr_type.into(),
2848 self.real_ptr_type.into(),
2849 self.real_ptr_type.into(),
2850 self.int_type.into(),
2851 self.int_type.into(),
2852 ],
2853 false,
2854 );
2855 let fn_arg_names = &["t", "dudt", "data", "rr", "thread_id", "thread_dim"];
2856 let function = self.module.add_function("mass", fn_type, None);
2857
2858 let alias_id = Attribute::get_named_enum_kind_id("noalias");
2860 let noalign = self.context.create_enum_attribute(alias_id, 0);
2861 for i in &[1, 2, 3] {
2862 function.add_attribute(AttributeLoc::Param(*i), noalign);
2863 }
2864
2865 let basic_block = self.context.append_basic_block(function, "entry");
2866 self.fn_value_opt = Some(function);
2867 self.builder.position_at_end(basic_block);
2868
2869 for (i, arg) in function.get_param_iter().enumerate() {
2870 let name = fn_arg_names[i];
2871 let alloca = self.function_arg_alloca(name, arg);
2872 self.insert_param(name, alloca);
2873 }
2874
2875 if model.state_dot().is_some() && model.lhs().is_some() {
2877 let state_dot = model.state_dot().unwrap();
2878 let lhs = model.lhs().unwrap();
2879
2880 self.insert_dot_state(state_dot);
2881 self.insert_data(model);
2882 self.insert_indices();
2883
2884 let mut nbarriers = 0;
2886 let total_barriers =
2887 (model.time_dep_defns().len() + model.dstate_dep_defns().len() + 1) as u64;
2888 for tensor in model.time_dep_defns() {
2889 self.jit_compile_tensor(tensor, Some(*self.get_var(tensor)))?;
2890 self.jit_compile_call_barrier(nbarriers, total_barriers);
2891 nbarriers += 1;
2892 }
2893
2894 for a in model.dstate_dep_defns() {
2895 self.jit_compile_tensor(a, Some(*self.get_var(a)))?;
2896 self.jit_compile_call_barrier(nbarriers, total_barriers);
2897 nbarriers += 1;
2898 }
2899
2900 let res_ptr = self.get_param("rr");
2902 self.jit_compile_tensor(lhs, Some(*res_ptr))?;
2903 self.jit_compile_call_barrier(nbarriers, total_barriers);
2904 }
2905
2906 self.builder.build_return(None)?;
2907
2908 if function.verify(true) {
2909 Ok(function)
2910 } else {
2911 function.print_to_stderr();
2912 unsafe {
2913 function.delete();
2914 }
2915 Err(anyhow!("Invalid generated function."))
2916 }
2917 }
2918
    /// Builds a wrapper function `fn_name` whose body calls an
    /// Enzyme-generated derivative of `original_function`.
    ///
    /// For each original parameter the wrapper takes the primal argument
    /// and, when `args_type[i]` is `Dup`/`DupNoNeed`, an additional shadow
    /// (derivative) argument immediately after it; `Const` arguments get no
    /// shadow. `mode` selects forward- or reverse-mode differentiation.
    /// The primal return value is discarded (the original returns void).
    pub fn compile_gradient(
        &mut self,
        original_function: FunctionValue<'ctx>,
        args_type: &[CompileGradientArgType],
        mode: CompileMode,
        fn_name: &str,
    ) -> Result<FunctionValue<'ctx>> {
        self.clear();

        // Parameter types of the wrapper we are about to build.
        let mut fn_type: Vec<BasicMetadataTypeEnum> = Vec::new();

        let orig_fn_type_ptr = Self::fn_pointer_type(self.context, original_function.get_type());

        // NOTE(review): `enzyme_fn_type` is accumulated here but never read
        // below — possibly vestigial; confirm before removing.
        let mut enzyme_fn_type: Vec<BasicMetadataTypeEnum> = vec![orig_fn_type_ptr.into()];
        // start_param_index[i] = index of original arg i's primal in the
        // wrapper's parameter list (shadow args shift later indices).
        let mut start_param_index: Vec<u32> = Vec::new();
        let mut ptr_arg_indices: Vec<u32> = Vec::new();
        for (i, arg) in original_function.get_param_iter().enumerate() {
            start_param_index.push(u32::try_from(fn_type.len()).unwrap());
            let arg_type = arg.get_type();
            fn_type.push(arg_type.into());

            enzyme_fn_type.push(self.int_type.into());
            enzyme_fn_type.push(arg.get_type().into());

            // NOTE(review): these indices refer to the ORIGINAL parameter
            // list; with earlier Dup shadows inserted the wrapper's
            // parameter numbering shifts — confirm the noalias attributes
            // below land on the intended wrapper params.
            if arg_type.is_pointer_type() {
                ptr_arg_indices.push(u32::try_from(i).unwrap());
            }

            match args_type[i] {
                // Dup / DupNoNeed args get a shadow parameter of the same
                // type immediately after the primal.
                CompileGradientArgType::Dup | CompileGradientArgType::DupNoNeed => {
                    fn_type.push(arg.get_type().into());
                    enzyme_fn_type.push(arg.get_type().into());
                }
                CompileGradientArgType::Const => {}
            }
        }
        let void_type = self.context.void_type();
        let fn_type = void_type.fn_type(fn_type.as_slice(), false);
        let function = self.module.add_function(fn_name, fn_type, None);

        // Mark pointer parameters noalias (see the shadow-shift note above).
        let alias_id = Attribute::get_named_enum_kind_id("noalias");
        let noalign = self.context.create_enum_attribute(alias_id, 0);
        for i in ptr_arg_indices {
            function.add_attribute(AttributeLoc::Param(i), noalign);
        }

        let basic_block = self.context.append_basic_block(function, "entry");
        self.fn_value_opt = Some(function);
        self.builder.position_at_end(basic_block);

        // Collect the call arguments for the generated derivative, the
        // per-arg activity codes, and an Enzyme type tree per argument
        // describing its memory layout for type analysis.
        let mut enzyme_fn_args: Vec<BasicMetadataValueEnum> = Vec::new();
        let mut input_activity = Vec::new();
        let mut arg_trees = Vec::new();
        for (i, arg) in original_function.get_param_iter().enumerate() {
            let param_index = start_param_index[i];
            let fn_arg = function.get_nth_param(param_index).unwrap();

            // Map the LLVM type to Enzyme's concrete-type enum.
            let concrete_type = match arg.get_type() {
                BasicTypeEnum::PointerType(_) => CConcreteType_DT_Pointer,
                BasicTypeEnum::FloatType(t) => {
                    if t == self.context.f64_type() {
                        CConcreteType_DT_Double
                    } else {
                        panic!("unsupported type")
                    }
                }
                BasicTypeEnum::IntType(_) => CConcreteType_DT_Integer,
                _ => panic!("unsupported type"),
            };
            let new_tree = unsafe {
                EnzymeNewTypeTreeCT(
                    concrete_type,
                    self.context.as_ctx_ref() as *mut LLVMOpaqueContext,
                )
            };
            unsafe { EnzymeTypeTreeOnlyEq(new_tree, -1) };

            // Pointers are described as pointer-to-double for type analysis
            // (all pointer params here point at real-valued arrays).
            if concrete_type == CConcreteType_DT_Pointer {
                let inner_concrete_type = CConcreteType_DT_Double;
                let inner_new_tree = unsafe {
                    EnzymeNewTypeTreeCT(
                        inner_concrete_type,
                        self.context.as_ctx_ref() as *mut LLVMOpaqueContext,
                    )
                };
                unsafe { EnzymeTypeTreeOnlyEq(inner_new_tree, -1) };
                unsafe { EnzymeTypeTreeOnlyEq(inner_new_tree, -1) };
                unsafe { EnzymeMergeTypeTree(new_tree, inner_new_tree) };
            }

            arg_trees.push(new_tree);
            match args_type[i] {
                CompileGradientArgType::Dup => {
                    // Pass primal followed by its shadow from the wrapper.
                    enzyme_fn_args.push(fn_arg.into());

                    let fn_darg = function.get_nth_param(param_index + 1).unwrap();
                    enzyme_fn_args.push(fn_darg.into());

                    input_activity.push(CDIFFE_TYPE_DFT_DUP_ARG);
                }
                CompileGradientArgType::DupNoNeed => {
                    // As Dup, but the primal result is not needed.
                    enzyme_fn_args.push(fn_arg.into());

                    let fn_darg = function.get_nth_param(param_index + 1).unwrap();
                    enzyme_fn_args.push(fn_darg.into());

                    input_activity.push(CDIFFE_TYPE_DFT_DUP_NONEED);
                }
                CompileGradientArgType::Const => {
                    enzyme_fn_args.push(fn_arg.into());

                    input_activity.push(CDIFFE_TYPE_DFT_CONSTANT);
                }
            }
        }
        // The original returns void: no primal return, no return derivative.
        let ret_primary_ret = false;
        let diff_ret = false;
        let ret_activity = CDIFFE_TYPE_DFT_CONSTANT;
        let ret_tree = unsafe {
            EnzymeNewTypeTreeCT(
                CConcreteType_DT_Anything,
                self.context.as_ctx_ref() as *mut LLVMOpaqueContext,
            )
        };

        // Enzyme logic + type analysis handles (freed below).
        let fnc_opt_base = true;
        let logic_ref: EnzymeLogicRef = unsafe { CreateEnzymeLogic(fnc_opt_base as u8) };

        // No known integer values for any argument.
        let kv_tmp = IntList {
            data: std::ptr::null_mut(),
            size: 0,
        };
        let mut known_values = vec![kv_tmp; input_activity.len()];

        let fn_type_info = CFnTypeInfo {
            Arguments: arg_trees.as_mut_ptr(),
            Return: ret_tree,
            KnownValues: known_values.as_mut_ptr(),
        };

        let type_analysis: EnzymeTypeAnalysisRef =
            unsafe { CreateTypeAnalysis(logic_ref, std::ptr::null_mut(), std::ptr::null_mut(), 0) };

        let mut args_uncacheable = vec![0; arg_trees.len()];

        let enzyme_function = match mode {
            CompileMode::Forward | CompileMode::ForwardSens => unsafe {
                EnzymeCreateForwardDiff(
                    logic_ref,
                    std::ptr::null_mut(),
                    std::ptr::null_mut(),
                    original_function.as_value_ref(),
                    ret_activity,
                    input_activity.as_mut_ptr(),
                    input_activity.len(),
                    type_analysis,
                    ret_primary_ret as u8,
                    CDerivativeMode_DEM_ForwardMode,
                    1,
                    0,
                    1,
                    std::ptr::null_mut(),
                    fn_type_info,
                    args_uncacheable.as_mut_ptr(),
                    args_uncacheable.len(),
                    std::ptr::null_mut(),
                )
            },
            CompileMode::Reverse | CompileMode::ReverseSens => {
                let mut call_enzyme = || unsafe {
                    EnzymeCreatePrimalAndGradient(
                        logic_ref,
                        std::ptr::null_mut(),
                        std::ptr::null_mut(),
                        original_function.as_value_ref(),
                        ret_activity,
                        input_activity.as_mut_ptr(),
                        input_activity.len(),
                        type_analysis,
                        ret_primary_ret as u8,
                        diff_ret as u8,
                        CDerivativeMode_DEM_ReverseModeCombined,
                        0,
                        1,
                        1,
                        std::ptr::null_mut(),
                        0,
                        fn_type_info,
                        args_uncacheable.as_mut_ptr(),
                        args_uncacheable.len(),
                        std::ptr::null_mut(),
                        if self.threaded { 1 } else { 0 },
                    )
                };
                if self.threaded {
                    // Enzyme's call-handler registry is global state: hold a
                    // lock while the "barrier" handlers are registered so
                    // concurrent compilations don't interleave, and
                    // unregister them before releasing.
                    let _lock = my_mutex.lock().unwrap();
                    let barrier_string = CString::new("barrier").unwrap();
                    unsafe {
                        EnzymeRegisterCallHandler(
                            barrier_string.as_ptr(),
                            Some(fwd_handler),
                            Some(rev_handler),
                        )
                    };
                    let ret = call_enzyme();
                    unsafe { EnzymeRegisterCallHandler(barrier_string.as_ptr(), None, None) };
                    ret
                } else {
                    call_enzyme()
                }
            }
        };

        // Release all Enzyme resources now that the derivative exists.
        unsafe { FreeEnzymeLogic(logic_ref) };
        unsafe { FreeTypeAnalysis(type_analysis) };
        unsafe { EnzymeFreeTypeTree(ret_tree) };
        for tree in arg_trees {
            unsafe { EnzymeFreeTypeTree(tree) };
        }

        // Wrap the raw LLVMValueRef Enzyme returned and call it from the
        // wrapper's entry block with the collected primal/shadow args.
        let enzyme_function =
            unsafe { FunctionValue::new(enzyme_function as LLVMValueRef) }.unwrap();
        self.builder
            .build_call(enzyme_function, enzyme_fn_args.as_slice(), "enzyme_call")?;

        self.builder.build_return(None)?;

        // Verify the generated IR; on failure dump both functions and delete
        // the wrapper so the module stays valid.
        if function.verify(true) {
            Ok(function)
        } else {
            function.print_to_stderr();
            enzyme_function.print_to_stderr();
            unsafe {
                function.delete();
            }
            Err(anyhow!("Invalid generated function."))
        }
    }
3176
3177 pub fn compile_get_dims(&mut self, model: &DiscreteModel) -> Result<FunctionValue<'ctx>> {
3178 self.clear();
3179 let fn_type = self.context.void_type().fn_type(
3180 &[
3181 self.int_ptr_type.into(),
3182 self.int_ptr_type.into(),
3183 self.int_ptr_type.into(),
3184 self.int_ptr_type.into(),
3185 self.int_ptr_type.into(),
3186 self.int_ptr_type.into(),
3187 ],
3188 false,
3189 );
3190
3191 let function = self.module.add_function("get_dims", fn_type, None);
3192 let block = self.context.append_basic_block(function, "entry");
3193 let fn_arg_names = &["states", "inputs", "outputs", "data", "stop", "has_mass"];
3194 self.builder.position_at_end(block);
3195
3196 for (i, arg) in function.get_param_iter().enumerate() {
3197 let name = fn_arg_names[i];
3198 let alloca = self.function_arg_alloca(name, arg);
3199 self.insert_param(name, alloca);
3200 }
3201
3202 self.insert_indices();
3203
3204 let number_of_states = model.state().nnz() as u64;
3205 let number_of_inputs = model.inputs().iter().fold(0, |acc, x| acc + x.nnz()) as u64;
3206 let number_of_outputs = match model.out() {
3207 Some(out) => out.nnz() as u64,
3208 None => 0,
3209 };
3210 let number_of_stop = if let Some(stop) = model.stop() {
3211 stop.nnz() as u64
3212 } else {
3213 0
3214 };
3215 let has_mass = match model.lhs().is_some() {
3216 true => 1u64,
3217 false => 0u64,
3218 };
3219 let data_len = self.layout.data().len() as u64;
3220 self.builder.build_store(
3221 *self.get_param("states"),
3222 self.int_type.const_int(number_of_states, false),
3223 )?;
3224 self.builder.build_store(
3225 *self.get_param("inputs"),
3226 self.int_type.const_int(number_of_inputs, false),
3227 )?;
3228 self.builder.build_store(
3229 *self.get_param("outputs"),
3230 self.int_type.const_int(number_of_outputs, false),
3231 )?;
3232 self.builder.build_store(
3233 *self.get_param("data"),
3234 self.int_type.const_int(data_len, false),
3235 )?;
3236 self.builder.build_store(
3237 *self.get_param("stop"),
3238 self.int_type.const_int(number_of_stop, false),
3239 )?;
3240 self.builder.build_store(
3241 *self.get_param("has_mass"),
3242 self.int_type.const_int(has_mass, false),
3243 )?;
3244 self.builder.build_return(None)?;
3245
3246 if function.verify(true) {
3247 Ok(function)
3248 } else {
3249 function.print_to_stderr();
3250 unsafe {
3251 function.delete();
3252 }
3253 Err(anyhow!("Invalid generated function."))
3254 }
3255 }
3256
3257 pub fn compile_get_tensor(
3258 &mut self,
3259 model: &DiscreteModel,
3260 name: &str,
3261 ) -> Result<FunctionValue<'ctx>> {
3262 self.clear();
3263 let real_ptr_ptr_type = Self::pointer_type(self.context, self.real_ptr_type.into());
3264 let fn_type = self.context.void_type().fn_type(
3265 &[
3266 self.real_ptr_type.into(),
3267 real_ptr_ptr_type.into(),
3268 self.int_ptr_type.into(),
3269 ],
3270 false,
3271 );
3272 let function_name = format!("get_tensor_{name}");
3273 let function = self
3274 .module
3275 .add_function(function_name.as_str(), fn_type, None);
3276 let basic_block = self.context.append_basic_block(function, "entry");
3277 self.fn_value_opt = Some(function);
3278
3279 let fn_arg_names = &["data", "tensor_data", "tensor_size"];
3280 self.builder.position_at_end(basic_block);
3281
3282 for (i, arg) in function.get_param_iter().enumerate() {
3283 let name = fn_arg_names[i];
3284 let alloca = self.function_arg_alloca(name, arg);
3285 self.insert_param(name, alloca);
3286 }
3287
3288 self.insert_data(model);
3289 let ptr = self.get_param(name);
3290 let tensor_size = self.layout.get_layout(name).unwrap().nnz() as u64;
3291 let tensor_size_value = self.int_type.const_int(tensor_size, false);
3292 self.builder
3293 .build_store(*self.get_param("tensor_data"), ptr.as_basic_value_enum())?;
3294 self.builder
3295 .build_store(*self.get_param("tensor_size"), tensor_size_value)?;
3296 self.builder.build_return(None)?;
3297
3298 if function.verify(true) {
3299 Ok(function)
3300 } else {
3301 function.print_to_stderr();
3302 unsafe {
3303 function.delete();
3304 }
3305 Err(anyhow!("Invalid generated function."))
3306 }
3307 }
3308
3309 pub fn compile_get_constant(
3310 &mut self,
3311 model: &DiscreteModel,
3312 name: &str,
3313 ) -> Result<FunctionValue<'ctx>> {
3314 self.clear();
3315 let real_ptr_ptr_type = Self::pointer_type(self.context, self.real_ptr_type.into());
3316 let fn_type = self
3317 .context
3318 .void_type()
3319 .fn_type(&[real_ptr_ptr_type.into(), self.int_ptr_type.into()], false);
3320 let function_name = format!("get_constant_{name}");
3321 let function = self
3322 .module
3323 .add_function(function_name.as_str(), fn_type, None);
3324 let basic_block = self.context.append_basic_block(function, "entry");
3325 self.fn_value_opt = Some(function);
3326
3327 let fn_arg_names = &["tensor_data", "tensor_size"];
3328 self.builder.position_at_end(basic_block);
3329
3330 for (i, arg) in function.get_param_iter().enumerate() {
3331 let name = fn_arg_names[i];
3332 let alloca = self.function_arg_alloca(name, arg);
3333 self.insert_param(name, alloca);
3334 }
3335
3336 self.insert_constants(model);
3337 let ptr = self.get_param(name);
3338 let tensor_size = self.layout.get_layout(name).unwrap().nnz() as u64;
3339 let tensor_size_value = self.int_type.const_int(tensor_size, false);
3340 self.builder
3341 .build_store(*self.get_param("tensor_data"), ptr.as_basic_value_enum())?;
3342 self.builder
3343 .build_store(*self.get_param("tensor_size"), tensor_size_value)?;
3344 self.builder.build_return(None)?;
3345
3346 if function.verify(true) {
3347 Ok(function)
3348 } else {
3349 function.print_to_stderr();
3350 unsafe {
3351 function.delete();
3352 }
3353 Err(anyhow!("Invalid generated function."))
3354 }
3355 }
3356
    /// Builds either the `get_inputs` or `set_inputs` LLVM function,
    /// selected by `is_get`.  The generated function copies every input
    /// tensor between the flat `inputs` array argument and the tensor's
    /// storage, one do-while loop per input tensor:
    ///
    /// * `is_get == true`  → tensor storage → `inputs` (read out)
    /// * `is_get == false` → `inputs` → tensor storage (write in)
    ///
    /// Returns the verified [`FunctionValue`], or an error (after deleting
    /// the function) if LLVM verification fails.
    pub fn compile_inputs(
        &mut self,
        model: &DiscreteModel,
        is_get: bool,
    ) -> Result<FunctionValue<'ctx>> {
        self.clear();
        // void (get|set)_inputs(real* inputs, real* data)
        let void_type = self.context.void_type();
        let fn_type = void_type.fn_type(
            &[self.real_ptr_type.into(), self.real_ptr_type.into()],
            false,
        );
        let function_name = if is_get { "get_inputs" } else { "set_inputs" };
        let function = self.module.add_function(function_name, fn_type, None);
        // `block` tracks the predecessor of the next loop header; it starts
        // as the entry block and is updated to each loop's post-block below.
        let mut block = self.context.append_basic_block(function, "entry");
        self.fn_value_opt = Some(function);

        let fn_arg_names = &["inputs", "data"];
        self.builder.position_at_end(block);

        // Spill each argument to an alloca and register it by name.
        for (i, arg) in function.get_param_iter().enumerate() {
            let name = fn_arg_names[i];
            let alloca = self.function_arg_alloca(name, arg);
            self.insert_param(name, alloca);
        }

        // Running offset of the current tensor within the flat inputs array.
        let mut inputs_index = 0usize;
        for input in model.inputs() {
            let name = format!("input_{}", input.name());
            self.insert_tensor(input, false);
            let ptr = self.get_var(input);
            let inputs_start_index = self.int_type.const_int(inputs_index as u64, false);
            let start_index = self.int_type.const_int(0, false);
            let end_index = self
                .int_type
                .const_int(input.nnz().try_into().unwrap(), false);

            // Loop header/body block: a single-block do-while loop over the
            // tensor's nnz entries, driven by a phi induction variable.
            let input_block = self.context.append_basic_block(function, name.as_str());
            self.builder.build_unconditional_branch(input_block)?;
            self.builder.position_at_end(input_block);
            let index = self.builder.build_phi(self.int_type, "i")?;
            // First incoming edge: i = 0 from the predecessor block; the
            // back-edge incoming is added after `next_index` is computed.
            index.add_incoming(&[(&start_index, block)]);

            let curr_input_index = index.as_basic_value().into_int_value();
            // Element i of this tensor's storage.
            let input_ptr = Self::get_ptr_to_index(
                &self.builder,
                self.real_type,
                ptr,
                curr_input_index,
                name.as_str(),
            );
            // Element (offset + i) of the flat inputs array.
            let curr_inputs_index =
                self.builder
                    .build_int_add(inputs_start_index, curr_input_index, name.as_str())?;
            let inputs_ptr = Self::get_ptr_to_index(
                &self.builder,
                self.real_type,
                self.get_param("inputs"),
                curr_inputs_index,
                name.as_str(),
            );
            // Copy one real in the direction selected by `is_get`.
            if is_get {
                let input_value = self
                    .build_load(self.real_type, input_ptr, name.as_str())?
                    .into_float_value();
                self.builder.build_store(inputs_ptr, input_value)?;
            } else {
                let input_value = self
                    .build_load(self.real_type, inputs_ptr, name.as_str())?
                    .into_float_value();
                self.builder.build_store(input_ptr, input_value)?;
            }

            let one = self.int_type.const_int(1, false);
            let next_index = self
                .builder
                .build_int_add(curr_input_index, one, name.as_str())?;
            // Back-edge incoming value for the phi: i = i + 1.
            index.add_incoming(&[(&next_index, input_block)]);

            // Do-while exit test: continue while next_index < nnz.
            // NOTE(review): the body executes before this test, so a tensor
            // with nnz() == 0 would still copy one element — presumably
            // every input has at least one entry; confirm upstream.
            let loop_while = self.builder.build_int_compare(
                IntPredicate::ULT,
                next_index,
                end_index,
                name.as_str(),
            )?;
            let post_block = self.context.append_basic_block(function, name.as_str());
            self.builder
                .build_conditional_branch(loop_while, input_block, post_block)?;
            self.builder.position_at_end(post_block);

            // The post-block becomes the predecessor of the next loop.
            block = post_block;
            inputs_index += input.nnz();
        }
        self.builder.build_return(None)?;

        if function.verify(true) {
            Ok(function)
        } else {
            // Dump the invalid IR for debugging, then discard the function.
            function.print_to_stderr();
            unsafe {
                function.delete();
            }
            Err(anyhow!("Invalid generated function."))
        }
    }
3466
    /// Builds the `set_id` LLVM function, which fills the caller-supplied
    /// `id` real array with one flag per state entry: 0.0 where the state
    /// block is algebraic, 1.0 otherwise.  One do-while loop is generated
    /// per state block.
    ///
    /// Returns the verified [`FunctionValue`], or an error (after deleting
    /// the function) if LLVM verification fails.
    pub fn compile_set_id(&mut self, model: &DiscreteModel) -> Result<FunctionValue<'ctx>> {
        self.clear();
        // void set_id(real* id)
        let void_type = self.context.void_type();
        let fn_type = void_type.fn_type(&[self.real_ptr_type.into()], false);
        let function = self.module.add_function("set_id", fn_type, None);
        // `block` tracks the predecessor of the next loop header; it starts
        // as the entry block and is updated to each loop's post-block below.
        let mut block = self.context.append_basic_block(function, "entry");

        let fn_arg_names = &["id"];
        self.builder.position_at_end(block);

        // Spill each argument to an alloca and register it by name.
        for (i, arg) in function.get_param_iter().enumerate() {
            let name = fn_arg_names[i];
            let alloca = self.function_arg_alloca(name, arg);
            self.insert_param(name, alloca);
        }

        // Running offset of the current state block within the id array.
        let mut id_index = 0usize;
        for (blk, is_algebraic) in zip(model.state().elmts(), model.is_algebraic()) {
            let name = blk.name().unwrap_or("unknown");
            let id_start_index = self.int_type.const_int(id_index as u64, false);
            let blk_start_index = self.int_type.const_int(0, false);
            let blk_end_index = self
                .int_type
                .const_int(blk.nnz().try_into().unwrap(), false);

            // Loop header/body block: a single-block do-while loop over the
            // block's nnz entries, driven by a phi induction variable.
            let blk_block = self.context.append_basic_block(function, name);
            self.builder.build_unconditional_branch(blk_block)?;
            self.builder.position_at_end(blk_block);
            let index = self.builder.build_phi(self.int_type, "i")?;
            // First incoming edge: i = 0 from the predecessor block; the
            // back-edge incoming is added after `next_index` is computed.
            index.add_incoming(&[(&blk_start_index, block)]);

            let curr_blk_index = index.as_basic_value().into_int_value();
            // Element (offset + i) of the id array.
            let curr_id_index = self
                .builder
                .build_int_add(id_start_index, curr_blk_index, name)?;
            let id_ptr = Self::get_ptr_to_index(
                &self.builder,
                self.real_type,
                self.get_param("id"),
                curr_id_index,
                name,
            );
            // The stored flag: 0.0 for algebraic entries, 1.0 otherwise.
            let is_algebraic_float = if *is_algebraic {
                0.0 as RealType
            } else {
                1.0 as RealType
            };
            let is_algebraic_value = self.real_type.const_float(is_algebraic_float);
            self.builder.build_store(id_ptr, is_algebraic_value)?;

            let one = self.int_type.const_int(1, false);
            let next_index = self.builder.build_int_add(curr_blk_index, one, name)?;
            // Back-edge incoming value for the phi: i = i + 1.
            index.add_incoming(&[(&next_index, blk_block)]);

            // Do-while exit test: continue while next_index < nnz.
            // NOTE(review): the body executes before this test, so a block
            // with nnz() == 0 would still write one element — presumably
            // every state block has at least one entry; confirm upstream.
            let loop_while = self.builder.build_int_compare(
                IntPredicate::ULT,
                next_index,
                blk_end_index,
                name,
            )?;
            let post_block = self.context.append_basic_block(function, name);
            self.builder
                .build_conditional_branch(loop_while, blk_block, post_block)?;
            self.builder.position_at_end(post_block);

            // The post-block becomes the predecessor of the next loop.
            block = post_block;
            id_index += blk.nnz();
        }
        self.builder.build_return(None)?;

        if function.verify(true) {
            Ok(function)
        } else {
            // Dump the invalid IR for debugging, then discard the function.
            function.print_to_stderr();
            unsafe {
                function.delete();
            }
            Err(anyhow!("Invalid generated function."))
        }
    }
3552
    /// Returns a shared reference to the underlying LLVM [`Module`].
    pub fn module(&self) -> &Module<'ctx> {
        &self.module
    }
3556}