1use aliasable::boxed::AliasableBox;
2use anyhow::{anyhow, Result};
3use inkwell::attributes::{Attribute, AttributeLoc};
4use inkwell::basic_block::BasicBlock;
5use inkwell::builder::Builder;
6use inkwell::context::{AsContextRef, Context};
7use inkwell::execution_engine::ExecutionEngine;
8use inkwell::intrinsics::Intrinsic;
9use inkwell::module::{Linkage, Module};
10use inkwell::passes::PassBuilderOptions;
11use inkwell::targets::{FileType, InitializationConfig, Target, TargetMachine, TargetTriple};
12use inkwell::types::{
13 BasicMetadataTypeEnum, BasicType, BasicTypeEnum, FloatType, FunctionType, IntType, PointerType,
14};
15use inkwell::values::{
16 AsValueRef, BasicMetadataValueEnum, BasicValue, BasicValueEnum, CallSiteValue, FloatValue,
17 FunctionValue, GlobalValue, IntValue, PointerValue,
18};
19use inkwell::{
20 AddressSpace, AtomicOrdering, AtomicRMWBinOp, FloatPredicate, GlobalVisibility, IntPredicate,
21 OptimizationLevel,
22};
23use llvm_sys::core::{
24 LLVMBuildCall2, LLVMGetArgOperand, LLVMGetBasicBlockParent, LLVMGetGlobalParent,
25 LLVMGetInstructionParent, LLVMGetNamedFunction, LLVMGlobalGetValueType, LLVMIsMultithreaded,
26};
27use llvm_sys::prelude::{LLVMBuilderRef, LLVMValueRef};
28use std::collections::HashMap;
29use std::ffi::CString;
30use std::iter::zip;
31use std::pin::Pin;
32use target_lexicon::Triple;
33
34use crate::ast::{Ast, AstKind};
35use crate::discretise::{DiscreteModel, Tensor, TensorBlock};
36use crate::enzyme::{
37 CConcreteType_DT_Anything, CConcreteType_DT_Double, CConcreteType_DT_Float,
38 CConcreteType_DT_Integer, CConcreteType_DT_Pointer, CDerivativeMode_DEM_ForwardMode,
39 CDerivativeMode_DEM_ReverseModeCombined, CFnTypeInfo, CreateEnzymeLogic, CreateTypeAnalysis,
40 DiffeGradientUtils, EnzymeCreateForwardDiff, EnzymeCreatePrimalAndGradient, EnzymeFreeTypeTree,
41 EnzymeGradientUtilsNewFromOriginal, EnzymeLogicRef, EnzymeMergeTypeTree, EnzymeNewTypeTreeCT,
42 EnzymeRegisterCallHandler, EnzymeTypeAnalysisRef, EnzymeTypeTreeOnlyEq, FreeEnzymeLogic,
43 FreeTypeAnalysis, GradientUtils, IntList, LLVMOpaqueContext, CDIFFE_TYPE_DFT_CONSTANT,
44 CDIFFE_TYPE_DFT_DUP_ARG, CDIFFE_TYPE_DFT_DUP_NONEED,
45};
46use crate::execution::compiler::CompilerMode;
47use crate::execution::module::{
48 CodegenModule, CodegenModuleCompile, CodegenModuleEmit, CodegenModuleJit,
49};
50use crate::execution::scalar::RealType;
51use crate::execution::{DataLayout, Translation, TranslationFrom, TranslationTo};
52use lazy_static::lazy_static;
53use std::sync::Mutex;
54
// Process-wide lock, presumably used to serialise calls into non-thread-safe
// Enzyme/LLVM entry points elsewhere in this module — usage not visible here,
// confirm before removing. NOTE(review): on newer toolchains
// `std::sync::Mutex::new` is const, so a plain `static` (or `LazyLock`) could
// replace `lazy_static!`.
lazy_static! {
    static ref my_mutex: Mutex<i32> = Mutex::new(0i32);
}
58
/// Self-referential pair of an LLVM `Context` and the `CodeGen` that borrows
/// from it. The `CodeGen<'static>` lifetime is a lie made sound by pinning:
/// `codegen` holds references into `context`, so this struct must never move
/// after `codegen` is set (see `LlvmModule::new`).
struct ImmovableLlvmModule {
    // Set after construction; transmuted to 'static but actually borrows `context`.
    codegen: Option<CodeGen<'static>>,
    // Boxed and aliasable so its heap address is stable while borrowed by `codegen`.
    context: AliasableBox<Context>,
    // Opts out of `Unpin` so the pin actually forbids moves.
    _pin: std::marker::PhantomPinned,
}
67
/// An LLVM-backed codegen module: the pinned context/codegen pair plus the
/// target machine used for optimisation passes and object emission.
pub struct LlvmModule {
    inner: Pin<Box<ImmovableLlvmModule>>,
    machine: TargetMachine,
}
72
// SAFETY: NOTE(review) — inkwell's `Context`/`Module` are not `Send`/`Sync`
// because LLVM contexts are not thread-safe. These impls assert that an
// `LlvmModule` is only ever accessed from one thread at a time; confirm all
// callers uphold that invariant.
unsafe impl Send for LlvmModule {}
unsafe impl Sync for LlvmModule {}
75
impl LlvmModule {
    /// Creates the target machine for `triple` (host if `None`) and builds the
    /// pinned, self-referential context + `CodeGen` pair.
    fn new(
        triple: Option<Triple>,
        model: &DiscreteModel,
        threaded: bool,
        real_type: RealType,
    ) -> Result<Self> {
        let initialization_config = &InitializationConfig::default();
        Target::initialize_all(initialization_config);
        let host_triple = Triple::host();
        // Native targets get the host CPU name/features; cross targets fall
        // back to "generic" with no extra features.
        let (triple_str, native) = match triple {
            Some(ref triple) => (triple.to_string(), false),
            None => (host_triple.to_string(), true),
        };
        let triple = TargetTriple::create(triple_str.as_str());
        let target = Target::from_triple(&triple).unwrap();
        let cpu = if native {
            TargetMachine::get_host_cpu_name().to_string()
        } else {
            "generic".to_string()
        };
        let features = if native {
            TargetMachine::get_host_cpu_features().to_string()
        } else {
            "".to_string()
        };
        let machine = target
            .create_target_machine(
                &triple,
                cpu.as_str(),
                features.as_str(),
                inkwell::OptimizationLevel::Aggressive,
                inkwell::targets::RelocMode::Default,
                inkwell::targets::CodeModel::Default,
            )
            .unwrap();

        // Heap-allocate the context with a stable, aliasable address;
        // `codegen` below will hold references into it.
        let context = AliasableBox::from_unique(Box::new(Context::create()));
        let mut pinned = Self {
            inner: Box::pin(ImmovableLlvmModule {
                codegen: None,
                context,
                _pin: std::marker::PhantomPinned,
            }),
            machine,
        };

        let context_ref = pinned.inner.context.as_ref();
        let real_type_llvm = match real_type {
            RealType::F32 => context_ref.f32_type(),
            RealType::F64 => context_ref.f64_type(),
        };
        let codegen = CodeGen::new(
            model,
            context_ref,
            real_type,
            real_type_llvm,
            context_ref.i32_type(),
            threaded,
        )?;
        // SAFETY: the 'static lifetime is fictitious — `codegen` really
        // borrows `pinned.inner.context`, which is boxed and pinned, so its
        // address never changes and it outlives `codegen`. `codegen` must be
        // dropped before (or together with) `inner`.
        let codegen = unsafe { std::mem::transmute::<CodeGen<'_>, CodeGen<'static>>(codegen) };
        // SAFETY: storing into the pinned struct does not move it.
        unsafe { pinned.inner.as_mut().get_unchecked_mut().codegen = Some(codegen) };
        Ok(pinned)
    }

    /// Runs an explicit O3-style pipeline before autodiff, with loop
    /// vectorisation/unrolling disabled — NOTE(review): presumably so Enzyme
    /// sees simpler loop structures to differentiate; confirm.
    fn pre_autodiff_optimisation(&mut self) -> Result<()> {
        let pass_options = PassBuilderOptions::create();
        pass_options.set_loop_vectorization(false);
        pass_options.set_loop_slp_vectorization(false);
        pass_options.set_loop_unrolling(false);
        // Hand-expanded `default<O3>` pipeline minus the vectorisation passes.
        let passes = "annotation2metadata,forceattrs,inferattrs,coro-early,function<eager-inv>(lower-expect,simplifycfg<bonus-inst-threshold=1;no-forward-switch-cond;no-switch-range-to-icmp;no-switch-to-lookup;keep-loops;no-hoist-common-insts;no-sink-common-insts>,early-cse<>),openmp-opt,ipsccp,called-value-propagation,globalopt,function(mem2reg),function<eager-inv>(instcombine,simplifycfg<bonus-inst-threshold=1;no-forward-switch-cond;switch-range-to-icmp;no-switch-to-lookup;keep-loops;no-hoist-common-insts;no-sink-common-insts>),require<globals-aa>,function(invalidate<aa>),require<profile-summary>,cgscc(devirt<4>(inline<only-mandatory>,inline,function-attrs,openmp-opt-cgscc,function<eager-inv>(early-cse<memssa>,speculative-execution,jump-threading,correlated-propagation,simplifycfg<bonus-inst-threshold=1;no-forward-switch-cond;switch-range-to-icmp;no-switch-to-lookup;keep-loops;no-hoist-common-insts;no-sink-common-insts>,instcombine,libcalls-shrinkwrap,tailcallelim,simplifycfg<bonus-inst-threshold=1;no-forward-switch-cond;switch-range-to-icmp;no-switch-to-lookup;keep-loops;no-hoist-common-insts;no-sink-common-insts>,reassociate,require<opt-remark-emit>,loop-mssa(loop-instsimplify,loop-simplifycfg,licm<no-allowspeculation>,loop-rotate,licm<allowspeculation>,simple-loop-unswitch<no-nontrivial;trivial>),simplifycfg<bonus-inst-threshold=1;no-forward-switch-cond;switch-range-to-icmp;no-switch-to-lookup;keep-loops;no-hoist-common-insts;no-sink-common-insts>,instcombine,loop(loop-idiom,indvars,loop-deletion),vector-combine,mldst-motion<no-split-footer-bb>,gvn<>,sccp,bdce,instcombine,jump-threading,correlated-propagation,adce,memcpyopt,dse,loop-mssa(licm<allowspeculation>),coro-elide,simplifycfg<bonus-inst-threshold=1;no-forward-switch-cond;switch-range-to-icmp;no-switch-to-lookup;keep-loops;hoist-common-insts;sink-common-insts>,instcombine),coro-split)),deadargelim,coro-cleanup,globalopt,globaldce,elim-avail-extern,rpo-function-attrs,recompute-globalsaa,function<eager-inv>(float2int,lower-constant-intrinsics,loop(loop-rotate,loop-deletion),loop-distribute,inject-tli-mappings,loop-load-elim,instcombine,simplifycfg<bonus-inst-threshold=1;forward-switch-cond;switch-range-to-icmp;switch-to-lookup;no-keep-loops;hoist-common-insts;sink-common-insts>,vector-combine,instcombine,transform-warning,instcombine,require<opt-remark-emit>,loop-mssa(licm<allowspeculation>),alignment-from-assumptions,loop-sink,instsimplify,div-rem-pairs,tailcallelim,simplifycfg<bonus-inst-threshold=1;no-forward-switch-cond;switch-range-to-icmp;no-switch-to-lookup;keep-loops;no-hoist-common-insts;no-sink-common-insts>),globaldce,constmerge,cg-profile,rel-lookup-table-converter,function(annotation-remarks),verify";
        let (codegen, machine) = self.codegen_and_machine_mut();
        codegen
            .module()
            .run_passes(passes, machine, pass_options)
            .map_err(|e| anyhow!("Failed to run passes: {:?}", e))
    }

    /// Cleans up after Enzyme and runs the full O3 pipeline.
    fn post_autodiff_optimisation(&mut self) -> Result<()> {
        // `barrier` is compiled with `noinline` (see `compile_barrier`); once
        // autodiff is done it may be inlined, so strip the attribute.
        if let Some(barrier_func) = self.codegen_mut().module().get_function("barrier") {
            let nolinline_kind_id = Attribute::get_named_enum_kind_id("noinline");
            barrier_func.remove_enum_attribute(AttributeLoc::Function, nolinline_kind_id);
        }

        // Delete leftover `preprocess_*` functions — NOTE(review): presumably
        // Enzyme-generated intermediates; confirm their origin.
        for f in self.codegen_mut().module.get_functions() {
            if f.get_name().to_str().unwrap().starts_with("preprocess_") {
                unsafe { f.delete() };
            }
        }

        let passes = "default<O3>";
        let (codegen, machine) = self.codegen_and_machine_mut();
        codegen
            .module()
            .run_passes(passes, machine, PassBuilderOptions::create())
            .map_err(|e| anyhow!("Failed to run passes: {:?}", e))?;

        Ok(())
    }

    /// Dumps the current module IR to stderr (debugging aid).
    pub fn print(&self) {
        self.codegen().module().print_to_stderr();
    }
    /// Mutable access to the codegen stored inside the pin.
    // SAFETY: mutating `codegen` in place does not move the pinned struct.
    fn codegen_mut(&mut self) -> &mut CodeGen<'static> {
        unsafe {
            self.inner
                .as_mut()
                .get_unchecked_mut()
                .codegen
                .as_mut()
                .unwrap()
        }
    }
    /// Simultaneous mutable codegen + shared machine borrow, needed because
    /// `run_passes` takes the machine while the codegen is borrowed mutably.
    fn codegen_and_machine_mut(&mut self) -> (&mut CodeGen<'static>, &TargetMachine) {
        (
            unsafe {
                self.inner
                    .as_mut()
                    .get_unchecked_mut()
                    .codegen
                    .as_mut()
                    .unwrap()
            },
            &self.machine,
        )
    }

    /// Shared access to the codegen; panics if called before `new` completed.
    fn codegen(&self) -> &CodeGen<'static> {
        self.inner.as_ref().get_ref().codegen.as_ref().unwrap()
    }
}
232
// Marker-trait implementation: `LlvmModule` is a usable codegen backend.
impl CodegenModule for LlvmModule {}
234
235impl CodegenModuleCompile for LlvmModule {
236 fn from_discrete_model(
237 model: &DiscreteModel,
238 mode: CompilerMode,
239 triple: Option<Triple>,
240 real_type: RealType,
241 ) -> Result<Self> {
242 let thread_dim = mode.thread_dim(model.state().nnz());
243 let threaded = thread_dim > 1;
244 if (unsafe { LLVMIsMultithreaded() } <= 0) {
245 return Err(anyhow!(
246 "LLVM is not compiled with multithreading support, but this codegen module requires it."
247 ));
248 }
249
250 let mut module = Self::new(triple, model, threaded, real_type)?;
251
252 let set_u0 = module.codegen_mut().compile_set_u0(model)?;
253 let _calc_stop = module.codegen_mut().compile_calc_stop(model)?;
254 let rhs = module.codegen_mut().compile_rhs(model, false)?;
255 let rhs_full = module.codegen_mut().compile_rhs(model, true)?;
256 let mass = module.codegen_mut().compile_mass(model)?;
257 let calc_out = module.codegen_mut().compile_calc_out(model, false)?;
258 let calc_out_full = module.codegen_mut().compile_calc_out(model, true)?;
259 let _set_id = module.codegen_mut().compile_set_id(model)?;
260 let _get_dims = module.codegen_mut().compile_get_dims(model)?;
261 let set_inputs = module.codegen_mut().compile_inputs(model, false)?;
262 let _get_inputs = module.codegen_mut().compile_inputs(model, true)?;
263 let _set_constants = module.codegen_mut().compile_set_constants(model)?;
264 let tensor_info = module
265 .codegen()
266 .layout
267 .tensors()
268 .map(|(name, is_constant)| (name.to_string(), is_constant))
269 .collect::<Vec<_>>();
270 for (tensor, is_constant) in tensor_info {
271 if is_constant {
272 module
273 .codegen_mut()
274 .compile_get_constant(model, tensor.as_str())?;
275 } else {
276 module
277 .codegen_mut()
278 .compile_get_tensor(model, tensor.as_str())?;
279 }
280 }
281
282 module.pre_autodiff_optimisation()?;
283
284 module.codegen_mut().compile_gradient(
285 set_u0,
286 &[
287 CompileGradientArgType::DupNoNeed,
288 CompileGradientArgType::DupNoNeed,
289 CompileGradientArgType::Const,
290 CompileGradientArgType::Const,
291 ],
292 CompileMode::Forward,
293 "set_u0_grad",
294 )?;
295
296 module.codegen_mut().compile_gradient(
297 rhs,
298 &[
299 CompileGradientArgType::Const,
300 CompileGradientArgType::DupNoNeed,
301 CompileGradientArgType::DupNoNeed,
302 CompileGradientArgType::DupNoNeed,
303 CompileGradientArgType::Const,
304 CompileGradientArgType::Const,
305 ],
306 CompileMode::Forward,
307 "rhs_grad",
308 )?;
309
310 module.codegen_mut().compile_gradient(
311 calc_out,
312 &[
313 CompileGradientArgType::Const,
314 CompileGradientArgType::DupNoNeed,
315 CompileGradientArgType::DupNoNeed,
316 CompileGradientArgType::DupNoNeed,
317 CompileGradientArgType::Const,
318 CompileGradientArgType::Const,
319 ],
320 CompileMode::Forward,
321 "calc_out_grad",
322 )?;
323 module.codegen_mut().compile_gradient(
324 set_inputs,
325 &[
326 CompileGradientArgType::DupNoNeed,
327 CompileGradientArgType::DupNoNeed,
328 ],
329 CompileMode::Forward,
330 "set_inputs_grad",
331 )?;
332
333 module.codegen_mut().compile_gradient(
334 set_u0,
335 &[
336 CompileGradientArgType::DupNoNeed,
337 CompileGradientArgType::DupNoNeed,
338 CompileGradientArgType::Const,
339 CompileGradientArgType::Const,
340 ],
341 CompileMode::Reverse,
342 "set_u0_rgrad",
343 )?;
344
345 module.codegen_mut().compile_gradient(
346 mass,
347 &[
348 CompileGradientArgType::Const,
349 CompileGradientArgType::DupNoNeed,
350 CompileGradientArgType::DupNoNeed,
351 CompileGradientArgType::DupNoNeed,
352 CompileGradientArgType::Const,
353 CompileGradientArgType::Const,
354 ],
355 CompileMode::Reverse,
356 "mass_rgrad",
357 )?;
358
359 module.codegen_mut().compile_gradient(
360 rhs,
361 &[
362 CompileGradientArgType::Const,
363 CompileGradientArgType::DupNoNeed,
364 CompileGradientArgType::DupNoNeed,
365 CompileGradientArgType::DupNoNeed,
366 CompileGradientArgType::Const,
367 CompileGradientArgType::Const,
368 ],
369 CompileMode::Reverse,
370 "rhs_rgrad",
371 )?;
372 module.codegen_mut().compile_gradient(
373 calc_out,
374 &[
375 CompileGradientArgType::Const,
376 CompileGradientArgType::DupNoNeed,
377 CompileGradientArgType::DupNoNeed,
378 CompileGradientArgType::DupNoNeed,
379 CompileGradientArgType::Const,
380 CompileGradientArgType::Const,
381 ],
382 CompileMode::Reverse,
383 "calc_out_rgrad",
384 )?;
385
386 module.codegen_mut().compile_gradient(
387 set_inputs,
388 &[
389 CompileGradientArgType::DupNoNeed,
390 CompileGradientArgType::DupNoNeed,
391 ],
392 CompileMode::Reverse,
393 "set_inputs_rgrad",
394 )?;
395
396 module.codegen_mut().compile_gradient(
397 rhs_full,
398 &[
399 CompileGradientArgType::Const,
400 CompileGradientArgType::Const,
401 CompileGradientArgType::DupNoNeed,
402 CompileGradientArgType::DupNoNeed,
403 CompileGradientArgType::Const,
404 CompileGradientArgType::Const,
405 ],
406 CompileMode::ForwardSens,
407 "rhs_sgrad",
408 )?;
409
410 module.codegen_mut().compile_gradient(
411 set_u0,
412 &[
413 CompileGradientArgType::DupNoNeed,
414 CompileGradientArgType::DupNoNeed,
415 CompileGradientArgType::Const,
416 CompileGradientArgType::Const,
417 ],
418 CompileMode::ForwardSens,
419 "set_u0_sgrad",
420 )?;
421
422 module.codegen_mut().compile_gradient(
423 calc_out_full,
424 &[
425 CompileGradientArgType::Const,
426 CompileGradientArgType::Const,
427 CompileGradientArgType::DupNoNeed,
428 CompileGradientArgType::DupNoNeed,
429 CompileGradientArgType::Const,
430 CompileGradientArgType::Const,
431 ],
432 CompileMode::ForwardSens,
433 "calc_out_sgrad",
434 )?;
435 module.codegen_mut().compile_gradient(
436 calc_out_full,
437 &[
438 CompileGradientArgType::Const,
439 CompileGradientArgType::Const,
440 CompileGradientArgType::DupNoNeed,
441 CompileGradientArgType::DupNoNeed,
442 CompileGradientArgType::Const,
443 CompileGradientArgType::Const,
444 ],
445 CompileMode::ReverseSens,
446 "calc_out_srgrad",
447 )?;
448
449 module.codegen_mut().compile_gradient(
450 rhs_full,
451 &[
452 CompileGradientArgType::Const,
453 CompileGradientArgType::Const,
454 CompileGradientArgType::DupNoNeed,
455 CompileGradientArgType::DupNoNeed,
456 CompileGradientArgType::Const,
457 CompileGradientArgType::Const,
458 ],
459 CompileMode::ReverseSens,
460 "rhs_srgrad",
461 )?;
462
463 module.post_autodiff_optimisation()?;
464 Ok(module)
465 }
466}
467
468impl CodegenModuleEmit for LlvmModule {
469 fn to_object(self) -> Result<Vec<u8>> {
470 let module = self.codegen().module();
471 let buffer = self
473 .machine
474 .write_to_memory_buffer(module, FileType::Object)
475 .unwrap()
476 .as_slice()
477 .to_vec();
478 Ok(buffer)
479 }
480}
481impl CodegenModuleJit for LlvmModule {
482 fn jit(&mut self) -> Result<HashMap<String, *const u8>> {
483 let ee = self
484 .codegen()
485 .module()
486 .create_jit_execution_engine(OptimizationLevel::Aggressive)
487 .map_err(|e| anyhow!("Failed to create JIT execution engine: {:?}", e))?;
488
489 let module = self.codegen().module();
490 let mut symbols = HashMap::new();
491 for function in module.get_functions() {
492 let name = function.get_name().to_str().unwrap();
493 let address = ee.get_function_address(name);
494 if let Ok(address) = address {
495 symbols.insert(name.to_string(), address as *const u8);
496 }
497 }
498 Ok(symbols)
499 }
500}
501
/// Module-level global variables shared by all generated functions.
struct Globals<'ctx> {
    // `enzyme_const_indices`: constant index table, if the layout has indices.
    indices: Option<GlobalValue<'ctx>>,
    // `enzyme_const_constants`: zero-initialised constants buffer, if any.
    constants: Option<GlobalValue<'ctx>>,
    // `enzyme_const_thread_counter`: counter backing the spin barrier;
    // only present in threaded mode.
    thread_counter: Option<GlobalValue<'ctx>>,
}
507
508impl<'ctx> Globals<'ctx> {
509 fn new(
510 layout: &DataLayout,
511 module: &Module<'ctx>,
512 int_type: IntType<'ctx>,
513 real_type: FloatType<'ctx>,
514 threaded: bool,
515 ) -> Self {
516 let thread_counter = if threaded {
517 let tc = module.add_global(
518 int_type,
519 Some(AddressSpace::default()),
520 "enzyme_const_thread_counter",
521 );
522 let tc_value = int_type.const_zero();
529 tc.set_visibility(GlobalVisibility::Hidden);
530 tc.set_initializer(&tc_value.as_basic_value_enum());
531 Some(tc)
532 } else {
533 None
534 };
535 let constants = if layout.constants().is_empty() {
536 None
537 } else {
538 let constants_array_type =
539 real_type.array_type(u32::try_from(layout.constants().len()).unwrap());
540 let constants = module.add_global(
541 constants_array_type,
542 Some(AddressSpace::default()),
543 "enzyme_const_constants",
544 );
545 constants.set_visibility(GlobalVisibility::Hidden);
546 constants.set_constant(false);
547 constants.set_initializer(&constants_array_type.const_zero());
548 Some(constants)
549 };
550 let indices = if layout.indices().is_empty() {
551 None
552 } else {
553 let indices_array_type =
554 int_type.array_type(u32::try_from(layout.indices().len()).unwrap());
555 let indices_array_values = layout
556 .indices()
557 .iter()
558 .map(|&i| int_type.const_int(i as u64, true))
559 .collect::<Vec<IntValue>>();
560 let indices_value = int_type.const_array(indices_array_values.as_slice());
561 let indices = module.add_global(
562 indices_array_type,
563 Some(AddressSpace::default()),
564 "enzyme_const_indices",
565 );
566 indices.set_constant(true);
567 indices.set_visibility(GlobalVisibility::Hidden);
568 indices.set_initializer(&indices_value);
569 Some(indices)
570 };
571 Self {
572 indices,
573 thread_counter,
574 constants,
575 }
576 }
577}
578
/// Activity of one argument when generating a gradient with Enzyme
/// (mirrors Enzyme's CDIFFE_TYPE constants: constant, duplicated, and
/// duplicated-with-unneeded-primal).
pub enum CompileGradientArgType {
    Const,
    Dup,
    DupNoNeed,
}
584
/// Which autodiff transformation to request from Enzyme: plain forward or
/// reverse mode, or their sensitivity variants.
pub enum CompileMode {
    Forward,
    ForwardSens,
    Reverse,
    ReverseSens,
}
591
/// Per-model LLVM IR emitter: owns the module/builder plus bookkeeping for
/// the function currently being compiled.
pub struct CodeGen<'ctx> {
    context: &'ctx inkwell::context::Context,
    module: Module<'ctx>,
    builder: Builder<'ctx>,
    // Name -> stack slot for the function currently being emitted.
    variables: HashMap<String, PointerValue<'ctx>>,
    // Finished helper functions looked up by name (e.g. "barrier").
    functions: HashMap<String, FunctionValue<'ctx>>,
    // Function under construction, if any.
    fn_value_opt: Option<FunctionValue<'ctx>>,
    tensor_ptr_opt: Option<PointerValue<'ctx>>,
    // Scalar precision requested by the caller (F32/F64) ...
    diffsl_real_type: RealType,
    // ... and its LLVM counterpart.
    real_type: FloatType<'ctx>,
    real_ptr_type: PointerType<'ctx>,
    int_type: IntType<'ctx>,
    int_ptr_type: PointerType<'ctx>,
    layout: DataLayout,
    globals: Globals<'ctx>,
    threaded: bool,
    _ee: Option<ExecutionEngine<'ctx>>,
}
610
/// Enzyme custom forward-mode call handler (registered for `barrier`).
/// NOTE(review): returning 1 appears to tell Enzyme the call requires no
/// shadow handling in the forward pass — confirm against Enzyme's
/// CustomCallHandler documentation before changing.
unsafe extern "C" fn fwd_handler(
    _builder: LLVMBuilderRef,
    _call_instruction: LLVMValueRef,
    _gutils: *mut GradientUtils,
    _dcall: *mut LLVMValueRef,
    _normal_return: *mut LLVMValueRef,
    _shadow_return: *mut LLVMValueRef,
) -> u8 {
    1
}
621
/// Enzyme custom reverse-mode handler for calls to `barrier`: emits a call to
/// `barrier_grad` in the generated adjoint, forwarding the original barrier
/// arguments mapped into the gradient function.
unsafe extern "C" fn rev_handler(
    builder: LLVMBuilderRef,
    call_instruction: LLVMValueRef,
    gutils: *mut DiffeGradientUtils,
    _tape: LLVMValueRef,
) {
    // Walk call -> block -> function -> module to find `barrier_grad`.
    let call_block = LLVMGetInstructionParent(call_instruction);
    let call_function = LLVMGetBasicBlockParent(call_block);
    let module = LLVMGetGlobalParent(call_function);
    let name_c_str = CString::new("barrier_grad").unwrap();
    let barrier_func = LLVMGetNamedFunction(module, name_c_str.as_ptr());
    let barrier_func_type = LLVMGlobalGetValueType(barrier_func);
    // Original call operands: (barrier_num, total_barriers, thread_count).
    let barrier_num = LLVMGetArgOperand(call_instruction, 0);
    let total_barriers = LLVMGetArgOperand(call_instruction, 1);
    let thread_count = LLVMGetArgOperand(call_instruction, 2);
    // Translate each primal value into its counterpart inside the function
    // Enzyme is generating.
    let barrier_num = EnzymeGradientUtilsNewFromOriginal(gutils as *mut GradientUtils, barrier_num);
    let total_barriers =
        EnzymeGradientUtilsNewFromOriginal(gutils as *mut GradientUtils, total_barriers);
    let thread_count =
        EnzymeGradientUtilsNewFromOriginal(gutils as *mut GradientUtils, thread_count);
    let mut args = [barrier_num, total_barriers, thread_count];
    // Empty result name: `barrier_grad` returns void.
    let name_c_str = CString::new("").unwrap();
    LLVMBuildCall2(
        builder,
        barrier_func_type,
        barrier_func,
        args.as_mut_ptr(),
        args.len() as u32,
        name_c_str.as_ptr(),
    );
}
653
// Value kinds accepted by the `compile_print_value` debugging helper.
#[allow(dead_code)]
enum PrintValue<'ctx> {
    Real(FloatValue<'ctx>),
    Int(IntValue<'ctx>),
}
659
660impl<'ctx> CodeGen<'ctx> {
661 pub fn new(
662 model: &DiscreteModel,
663 context: &'ctx inkwell::context::Context,
664 diffsl_real_type: RealType,
665 real_type: FloatType<'ctx>,
666 int_type: IntType<'ctx>,
667 threaded: bool,
668 ) -> Result<Self> {
669 let builder = context.create_builder();
670 let layout = DataLayout::new(model);
671 let module = context.create_module(model.name());
672 let globals = Globals::new(&layout, &module, int_type, real_type, threaded);
673 let real_ptr_type = Self::pointer_type(context, real_type.into());
674 let int_ptr_type = Self::pointer_type(context, int_type.into());
675 let mut ret = Self {
676 context,
677 module,
678 builder,
679 real_type,
680 real_ptr_type,
681 variables: HashMap::new(),
682 functions: HashMap::new(),
683 fn_value_opt: None,
684 tensor_ptr_opt: None,
685 layout,
686 diffsl_real_type,
687 int_type,
688 int_ptr_type,
689 globals,
690 threaded,
691 _ee: None,
692 };
693 if threaded {
694 ret.compile_barrier_init()?;
695 ret.compile_barrier()?;
696 ret.compile_barrier_grad()?;
697 }
700 Ok(ret)
701 }
702
    /// Debugging helper: emits a `printf` call that prints `name: <value>`
    /// at the builder's current insertion point. Declares `printf` and a
    /// per-name format-string global on first use, reusing them afterwards.
    #[allow(dead_code)]
    fn compile_print_value(
        &mut self,
        name: &str,
        value: PrintValue<'ctx>,
    ) -> Result<CallSiteValue<'_>> {
        let void_type = self.context.void_type();
        // Variadic `printf(ptr, ...)`; declared lazily with external linkage.
        let printf_type = void_type.fn_type(&[self.int_ptr_type.into()], true);
        let printf = match self.module.get_function("printf") {
            Some(f) => f,
            None => self
                .module
                .add_function("printf", printf_type, Some(Linkage::External)),
        };
        // %f for reals, %d for ints; one cached global per printed name.
        let (format_str, format_str_name) = match value {
            PrintValue::Real(_) => (format!("{name}: %f\n"), format!("real_format_{name}")),
            PrintValue::Int(_) => (format!("{name}: %d\n"), format!("int_format_{name}")),
        };
        let format_str = CString::new(format_str).unwrap();
        let format_str_global = match self.module.get_global(format_str_name.as_str()) {
            Some(g) => g,
            None => {
                let format_str = self.context.const_string(format_str.as_bytes(), true);
                let fmt_str =
                    self.module
                        .add_global(format_str.get_type(), None, format_str_name.as_str());
                fmt_str.set_initializer(&format_str);
                fmt_str.set_visibility(GlobalVisibility::Hidden);
                fmt_str
            }
        };
        // Cast the array global to a plain pointer for the printf argument.
        let format_str_ptr = self.builder.build_pointer_cast(
            format_str_global.as_pointer_value(),
            self.int_ptr_type,
            "format_str_ptr",
        )?;
        let value: BasicMetadataValueEnum = match value {
            PrintValue::Real(v) => v.into(),
            PrintValue::Int(v) => v.into(),
        };
        self.builder
            .build_call(printf, &[format_str_ptr.into(), value], "printf_call")
            .map_err(|e| anyhow!("Error building call to printf: {}", e))
    }
752
753 fn compile_set_constants(&mut self, model: &DiscreteModel) -> Result<FunctionValue<'ctx>> {
754 self.clear();
755 let void_type = self.context.void_type();
756 let fn_type = void_type.fn_type(&[self.int_type.into(), self.int_type.into()], false);
757 let fn_arg_names = &["thread_id", "thread_dim"];
758 let function = self.module.add_function("set_constants", fn_type, None);
759
760 let basic_block = self.context.append_basic_block(function, "entry");
761 self.fn_value_opt = Some(function);
762 self.builder.position_at_end(basic_block);
763
764 for (i, arg) in function.get_param_iter().enumerate() {
765 let name = fn_arg_names[i];
766 let alloca = self.function_arg_alloca(name, arg);
767 self.insert_param(name, alloca);
768 }
769
770 self.insert_indices();
771 self.insert_constants(model);
772
773 let mut nbarriers = 0;
774 let total_barriers = (model.constant_defns().len()) as u64;
775 #[allow(clippy::explicit_counter_loop)]
776 for a in model.constant_defns() {
777 self.jit_compile_tensor(a, Some(*self.get_var(a)))?;
778 self.jit_compile_call_barrier(nbarriers, total_barriers);
779 nbarriers += 1;
780 }
781
782 self.builder.build_return(None)?;
783
784 if function.verify(true) {
785 Ok(function)
786 } else {
787 function.print_to_stderr();
788 unsafe {
789 function.delete();
790 }
791 Err(anyhow!("Invalid generated function."))
792 }
793 }
794
795 fn compile_barrier_init(&mut self) -> Result<FunctionValue<'ctx>> {
796 self.clear();
797 let void_type = self.context.void_type();
798 let fn_type = void_type.fn_type(&[], false);
799 let function = self.module.add_function("barrier_init", fn_type, None);
800
801 let entry_block = self.context.append_basic_block(function, "entry");
802
803 self.fn_value_opt = Some(function);
804 self.builder.position_at_end(entry_block);
805
806 let thread_counter = self.globals.thread_counter.unwrap().as_pointer_value();
807 self.builder
808 .build_store(thread_counter, self.int_type.const_zero())?;
809
810 self.builder.build_return(None)?;
811
812 if function.verify(true) {
813 self.functions.insert("barrier_init".to_owned(), function);
814 Ok(function)
815 } else {
816 function.print_to_stderr();
817 unsafe {
818 function.delete();
819 }
820 Err(anyhow!("Invalid generated function."))
821 }
822 }
823
824 fn compile_barrier(&mut self) -> Result<FunctionValue<'ctx>> {
825 self.clear();
826 let void_type = self.context.void_type();
827 let fn_type = void_type.fn_type(
828 &[
829 self.int_type.into(),
830 self.int_type.into(),
831 self.int_type.into(),
832 ],
833 false,
834 );
835 let function = self.module.add_function("barrier", fn_type, None);
836 let nolinline_kind_id = Attribute::get_named_enum_kind_id("noinline");
837 let noinline = self.context.create_enum_attribute(nolinline_kind_id, 0);
838 function.add_attribute(AttributeLoc::Function, noinline);
839
840 let entry_block = self.context.append_basic_block(function, "entry");
841 let increment_block = self.context.append_basic_block(function, "increment");
842 let wait_loop_block = self.context.append_basic_block(function, "wait_loop");
843 let barrier_done_block = self.context.append_basic_block(function, "barrier_done");
844
845 self.fn_value_opt = Some(function);
846 self.builder.position_at_end(entry_block);
847
848 let thread_counter = self.globals.thread_counter.unwrap().as_pointer_value();
849 let barrier_num = function.get_nth_param(0).unwrap().into_int_value();
850 let total_barriers = function.get_nth_param(1).unwrap().into_int_value();
851 let thread_count = function.get_nth_param(2).unwrap().into_int_value();
852
853 let nbarrier_equals_total_barriers = self
854 .builder
855 .build_int_compare(
856 IntPredicate::EQ,
857 barrier_num,
858 total_barriers,
859 "nbarrier_equals_total_barriers",
860 )
861 .unwrap();
862 self.builder.build_conditional_branch(
864 nbarrier_equals_total_barriers,
865 barrier_done_block,
866 increment_block,
867 )?;
868 self.builder.position_at_end(increment_block);
869
870 let barrier_num_times_thread_count = self
871 .builder
872 .build_int_mul(barrier_num, thread_count, "barrier_num_times_thread_count")
873 .unwrap();
874
875 let i32_type = self.context.i32_type();
877 let one = i32_type.const_int(1, false);
878 self.builder.build_atomicrmw(
879 AtomicRMWBinOp::Add,
880 thread_counter,
881 one,
882 AtomicOrdering::Monotonic,
883 )?;
884
885 self.builder.build_unconditional_branch(wait_loop_block)?;
887 self.builder.position_at_end(wait_loop_block);
888
889 let current_value = self
890 .builder
891 .build_load(i32_type, thread_counter, "current_value")?
892 .into_int_value();
893
894 current_value
895 .as_instruction_value()
896 .unwrap()
897 .set_atomic_ordering(AtomicOrdering::Monotonic)
898 .map_err(|e| anyhow!("Error setting atomic ordering: {:?}", e))?;
899
900 let all_threads_done = self.builder.build_int_compare(
901 IntPredicate::UGE,
902 current_value,
903 barrier_num_times_thread_count,
904 "all_threads_done",
905 )?;
906
907 self.builder.build_conditional_branch(
908 all_threads_done,
909 barrier_done_block,
910 wait_loop_block,
911 )?;
912 self.builder.position_at_end(barrier_done_block);
913
914 self.builder.build_return(None)?;
915
916 if function.verify(true) {
917 self.functions.insert("barrier".to_owned(), function);
918 Ok(function)
919 } else {
920 function.print_to_stderr();
921 unsafe {
922 function.delete();
923 }
924 Err(anyhow!("Invalid generated function."))
925 }
926 }
927
928 fn compile_barrier_grad(&mut self) -> Result<FunctionValue<'ctx>> {
929 self.clear();
930 let void_type = self.context.void_type();
931 let fn_type = void_type.fn_type(
932 &[
933 self.int_type.into(),
934 self.int_type.into(),
935 self.int_type.into(),
936 ],
937 false,
938 );
939 let function = self.module.add_function("barrier_grad", fn_type, None);
940
941 let entry_block = self.context.append_basic_block(function, "entry");
942 let wait_loop_block = self.context.append_basic_block(function, "wait_loop");
943 let barrier_done_block = self.context.append_basic_block(function, "barrier_done");
944
945 self.fn_value_opt = Some(function);
946 self.builder.position_at_end(entry_block);
947
948 let thread_counter = self.globals.thread_counter.unwrap().as_pointer_value();
949 let barrier_num = function.get_nth_param(0).unwrap().into_int_value();
950 let total_barriers = function.get_nth_param(1).unwrap().into_int_value();
951 let thread_count = function.get_nth_param(2).unwrap().into_int_value();
952
953 let twice_total_barriers = self
954 .builder
955 .build_int_mul(
956 total_barriers,
957 self.int_type.const_int(2, false),
958 "twice_total_barriers",
959 )
960 .unwrap();
961 let twice_total_barriers_minus_barrier_num = self
962 .builder
963 .build_int_sub(
964 twice_total_barriers,
965 barrier_num,
966 "twice_total_barriers_minus_barrier_num",
967 )
968 .unwrap();
969 let twice_total_barriers_minus_barrier_num_times_thread_count = self
970 .builder
971 .build_int_mul(
972 twice_total_barriers_minus_barrier_num,
973 thread_count,
974 "twice_total_barriers_minus_barrier_num_times_thread_count",
975 )
976 .unwrap();
977
978 let i32_type = self.context.i32_type();
980 let one = i32_type.const_int(1, false);
981 self.builder.build_atomicrmw(
982 AtomicRMWBinOp::Add,
983 thread_counter,
984 one,
985 AtomicOrdering::Monotonic,
986 )?;
987
988 self.builder.build_unconditional_branch(wait_loop_block)?;
990 self.builder.position_at_end(wait_loop_block);
991
992 let current_value = self
993 .builder
994 .build_load(i32_type, thread_counter, "current_value")?
995 .into_int_value();
996 current_value
997 .as_instruction_value()
998 .unwrap()
999 .set_atomic_ordering(AtomicOrdering::Monotonic)
1000 .map_err(|e| anyhow!("Error setting atomic ordering: {:?}", e))?;
1001
1002 let all_threads_done = self.builder.build_int_compare(
1003 IntPredicate::UGE,
1004 current_value,
1005 twice_total_barriers_minus_barrier_num_times_thread_count,
1006 "all_threads_done",
1007 )?;
1008
1009 self.builder.build_conditional_branch(
1010 all_threads_done,
1011 barrier_done_block,
1012 wait_loop_block,
1013 )?;
1014 self.builder.position_at_end(barrier_done_block);
1015
1016 self.builder.build_return(None)?;
1017
1018 if function.verify(true) {
1019 self.functions.insert("barrier_grad".to_owned(), function);
1020 Ok(function)
1021 } else {
1022 function.print_to_stderr();
1023 unsafe {
1024 function.delete();
1025 }
1026 Err(anyhow!("Invalid generated function."))
1027 }
1028 }
1029
1030 fn jit_compile_call_barrier(&mut self, nbarrier: u64, total_barriers: u64) {
1031 if !self.threaded {
1032 return;
1033 }
1034 let thread_dim = self.get_param("thread_dim");
1035 let thread_dim = self
1036 .builder
1037 .build_load(self.int_type, *thread_dim, "thread_dim")
1038 .unwrap()
1039 .into_int_value();
1040 let nbarrier = self.int_type.const_int(nbarrier + 1, false);
1041 let total_barriers = self.int_type.const_int(total_barriers, false);
1042 let barrier = self.get_function("barrier").unwrap();
1043 self.builder
1044 .build_call(
1045 barrier,
1046 &[
1047 BasicMetadataValueEnum::IntValue(nbarrier),
1048 BasicMetadataValueEnum::IntValue(total_barriers),
1049 BasicMetadataValueEnum::IntValue(thread_dim),
1050 ],
1051 "barrier",
1052 )
1053 .unwrap();
1054 }
1055
/// Compute this thread's half-open slice `[start, end)` of a loop over
/// `size` elements, dividing the iteration space evenly over
/// `thread_dim` threads.
///
/// Also splits the CFG: the current block becomes `test_done` (where
/// `jit_end_threading` later emits the "any work to do?" check) and a
/// fresh `threading_block` (`next`) is appended and made the insertion
/// point for the loop body that follows.
///
/// Returns `(start, end, test_done, next)`; the caller must finish the
/// wiring with `jit_end_threading`.
fn jit_threading_limits(
    &mut self,
    size: IntValue<'ctx>,
) -> Result<(
    IntValue<'ctx>,
    IntValue<'ctx>,
    BasicBlock<'ctx>,
    BasicBlock<'ctx>,
)> {
    let one = self.int_type.const_int(1, false);
    // Runtime thread index and total thread count, loaded from params.
    let thread_id = self.get_param("thread_id");
    let thread_id = self
        .builder
        .build_load(self.int_type, *thread_id, "thread_id")
        .unwrap()
        .into_int_value();
    let thread_dim = self.get_param("thread_dim");
    let thread_dim = self
        .builder
        .build_load(self.int_type, *thread_dim, "thread_dim")
        .unwrap()
        .into_int_value();

    // start = (thread_id * size) / thread_dim
    let i_times_size = self
        .builder
        .build_int_mul(thread_id, size, "i_times_size")?;
    let start = self
        .builder
        .build_int_unsigned_div(i_times_size, thread_dim, "start")?;

    // end = ((thread_id + 1) * size) / thread_dim
    let i_plus_one = self.builder.build_int_add(thread_id, one, "i_plus_one")?;
    let i_plus_one_times_size =
        self.builder
            .build_int_mul(i_plus_one, size, "i_plus_one_times_size")?;
    let end = self
        .builder
        .build_int_unsigned_div(i_plus_one_times_size, thread_dim, "end")?;

    // Remember where we were (the eventual "is my slice empty?" test
    // lives there) and start a fresh block for the loop body.
    let test_done = self.builder.get_insert_block().unwrap();
    let next_block = self
        .context
        .append_basic_block(self.fn_value_opt.unwrap(), "threading_block");
    self.builder.position_at_end(next_block);

    Ok((start, end, test_done, next_block))
}
1104
/// Close the threaded region opened by `jit_threading_limits`.
///
/// Appends an `exit` block, branches the (already compiled) loop tail
/// into it, then returns to `test_done` and emits the guard: a thread
/// whose slice is empty (`start == end`) jumps straight to `exit`,
/// otherwise it enters `next` (the loop body). Leaves the builder
/// positioned at `exit`.
fn jit_end_threading(
    &mut self,
    start: IntValue<'ctx>,
    end: IntValue<'ctx>,
    test_done: BasicBlock<'ctx>,
    next: BasicBlock<'ctx>,
) -> Result<()> {
    let exit = self
        .context
        .append_basic_block(self.fn_value_opt.unwrap(), "exit");
    // Fall through from the end of the loop body into exit.
    self.builder.build_unconditional_branch(exit)?;
    // Back-fill the slice-empty test at the region entry.
    self.builder.position_at_end(test_done);
    let done = self
        .builder
        .build_int_compare(IntPredicate::EQ, start, end, "done")?;
    self.builder.build_conditional_branch(done, exit, next)?;
    self.builder.position_at_end(exit);
    Ok(())
}
1125
/// Serialise the compiled module as LLVM bitcode to `path`.
pub fn write_bitcode_to_path(&self, path: &std::path::Path) {
    self.module.write_bitcode_to_path(path);
}
1129
1130 fn insert_constants(&mut self, model: &DiscreteModel) {
1131 if let Some(constants) = self.globals.constants.as_ref() {
1132 self.insert_param("constants", constants.as_pointer_value());
1133 for tensor in model.constant_defns() {
1134 self.insert_tensor(tensor, true);
1135 }
1136 }
1137 }
1138
/// Register pointers for all model tensors that live in the data
/// buffer: constants first, then the input tensor and the input-,
/// time- and state-dependent definitions.
fn insert_data(&mut self, model: &DiscreteModel) {
    self.insert_constants(model);

    if let Some(input) = model.input() {
        self.insert_tensor(input, false);
    }
    for tensor in model.input_dep_defns() {
        self.insert_tensor(tensor, false);
    }
    for tensor in model.time_dep_defns() {
        self.insert_tensor(tensor, false);
    }
    for tensor in model.state_dep_defns() {
        self.insert_tensor(tensor, false);
    }
}
1155
/// Opaque-pointer helper: with LLVM opaque pointers all pointers share
/// one type, so the pointee type argument is ignored.
fn pointer_type(context: &'ctx Context, _ty: BasicTypeEnum<'ctx>) -> PointerType<'ctx> {
    context.ptr_type(AddressSpace::default())
}
1159
/// Opaque-pointer helper for function pointers; the function type is
/// ignored for the same reason as in `pointer_type`.
fn fn_pointer_type(context: &'ctx Context, _ty: FunctionType<'ctx>) -> PointerType<'ctx> {
    context.ptr_type(AddressSpace::default())
}
1163
/// Expose the global sparse-index array (if one was created) as the
/// "indices" variable.
fn insert_indices(&mut self) {
    if let Some(indices) = self.globals.indices.as_ref() {
        let i32_type = self.context.i32_type();
        let zero = i32_type.const_int(0, false);
        // SAFETY: constant GEP with offset 0 into the indices global
        // cannot leave the bounds of the allocation.
        let ptr = unsafe {
            indices
                .as_pointer_value()
                .const_in_bounds_gep(i32_type, &[zero])
        };
        self.variables.insert("indices".to_owned(), ptr);
    }
}
1176
/// Record a named parameter pointer in the variables table
/// (overwrites any existing entry of the same name).
fn insert_param(&mut self, name: &str, value: PointerValue<'ctx>) {
    self.variables.insert(name.to_owned(), value);
}
1180
/// GEP wrapper converting inkwell's builder error into `anyhow::Error`.
///
/// # Safety
/// Delegates to `Builder::build_gep`; callers must ensure the indexes
/// stay in bounds of the pointed-to object (same contract as LLVM GEP).
fn build_gep<T: BasicType<'ctx>>(
    &self,
    ty: T,
    ptr: PointerValue<'ctx>,
    ordered_indexes: &[IntValue<'ctx>],
    name: &str,
) -> Result<PointerValue<'ctx>> {
    unsafe {
        self.builder
            .build_gep(ty, ptr, ordered_indexes, name)
            .map_err(|e| e.into())
    }
}
1194
1195 fn build_load<T: BasicType<'ctx>>(
1196 &self,
1197 ty: T,
1198 ptr: PointerValue<'ctx>,
1199 name: &str,
1200 ) -> Result<BasicValueEnum<'ctx>> {
1201 self.builder.build_load(ty, ptr, name).map_err(|e| e.into())
1202 }
1203
/// Return a pointer to element `index` of the array at `ptr`, using the
/// supplied builder (typically the entry-block builder so the GEP
/// dominates all uses).
///
/// SAFETY: in-bounds GEP; callers must pass an index within the
/// allocation pointed to by `ptr`.
fn get_ptr_to_index<T: BasicType<'ctx>>(
    builder: &Builder<'ctx>,
    ty: T,
    ptr: &PointerValue<'ctx>,
    index: IntValue<'ctx>,
    name: &str,
) -> PointerValue<'ctx> {
    unsafe {
        builder
            .build_in_bounds_gep(ty, *ptr, &[index], name)
            .unwrap()
    }
}
1217
1218 fn insert_state(&mut self, u: &Tensor) {
1219 let mut data_index = 0;
1220 for blk in u.elmts() {
1221 if let Some(name) = blk.name() {
1222 let ptr = self.variables.get("u").unwrap();
1223 let i = self
1224 .context
1225 .i32_type()
1226 .const_int(data_index.try_into().unwrap(), false);
1227 let alloca = Self::get_ptr_to_index(
1228 &self.create_entry_block_builder(),
1229 self.real_type,
1230 ptr,
1231 i,
1232 blk.name().unwrap(),
1233 );
1234 self.variables.insert(name.to_owned(), alloca);
1235 }
1236 data_index += blk.nnz();
1237 }
1238 }
1239 fn insert_dot_state(&mut self, dudt: &Tensor) {
1240 let mut data_index = 0;
1241 for blk in dudt.elmts() {
1242 if let Some(name) = blk.name() {
1243 let ptr = self.variables.get("dudt").unwrap();
1244 let i = self
1245 .context
1246 .i32_type()
1247 .const_int(data_index.try_into().unwrap(), false);
1248 let alloca = Self::get_ptr_to_index(
1249 &self.create_entry_block_builder(),
1250 self.real_type,
1251 ptr,
1252 i,
1253 blk.name().unwrap(),
1254 );
1255 self.variables.insert(name.to_owned(), alloca);
1256 }
1257 data_index += blk.nnz();
1258 }
1259 }
/// Register a pointer for `tensor` (and each of its named blocks) into
/// the backing storage: the "constants" global when `is_constant`,
/// otherwise the "data" buffer.
fn insert_tensor(&mut self, tensor: &Tensor, is_constant: bool) {
    let var_name = if is_constant { "constants" } else { "data" };
    let ptr = *self.variables.get(var_name).unwrap();
    // Element offset of this tensor within the backing storage.
    let mut data_index = self.layout.get_data_index(tensor.name()).unwrap();
    let i = self
        .context
        .i32_type()
        .const_int(data_index.try_into().unwrap(), false);
    let alloca = Self::get_ptr_to_index(
        &self.create_entry_block_builder(),
        self.real_type,
        &ptr,
        i,
        tensor.name(),
    );
    self.variables.insert(tensor.name().to_owned(), alloca);

    // Also register each named block, advancing the offset by every
    // block's stored (non-zero) entry count.
    for blk in tensor.elmts() {
        if let Some(name) = blk.name() {
            let i = self
                .context
                .i32_type()
                .const_int(data_index.try_into().unwrap(), false);
            let alloca = Self::get_ptr_to_index(
                &self.create_entry_block_builder(),
                self.real_type,
                &ptr,
                i,
                name,
            );
            self.variables.insert(name.to_owned(), alloca);
        }
        data_index += blk.nnz();
    }
}
/// Look up a registered parameter pointer by name; panics if the
/// parameter was never inserted (a codegen invariant violation).
fn get_param(&self, name: &str) -> &PointerValue<'ctx> {
    self.variables.get(name).unwrap()
}
1300
/// Look up the registered pointer for a tensor; panics if the tensor
/// was never inserted (a codegen invariant violation).
fn get_var(&self, tensor: &Tensor) -> &PointerValue<'ctx> {
    self.variables.get(tensor.name()).unwrap()
}
1304
/// Look up a math function by name, lazily creating it on first use and
/// caching it in `self.functions`. Returns `None` for unknown names.
///
/// Simple ops map to LLVM intrinsics (note `min`/`max`/`abs` are
/// renamed to the `minnum`/`maxnum`/`fabs` intrinsic spellings); the
/// remaining functions are synthesised as small IR functions built from
/// `exp`, `sqrt` and `log`. The synthesised branches save and restore
/// the builder's insertion point around the new function body.
fn get_function(&mut self, name: &str) -> Option<FunctionValue<'ctx>> {
    match self.functions.get(name) {
        Some(&func) => Some(func),
        None => {
            let function = match name {
                // Directly available as overloaded LLVM intrinsics.
                // NOTE(review): arg_len = 1 here is the number of types
                // used to resolve the intrinsic overload, not the call
                // arity — pow/copysign/min/max still take two runtime
                // arguments; confirm against Intrinsic::get_declaration.
                "sin" | "cos" | "tan" | "exp" | "log" | "log10" | "sqrt" | "abs"
                | "copysign" | "pow" | "min" | "max" => {
                    let arg_len = 1;
                    let intrinsic_name = match name {
                        "min" => "minnum",
                        "max" => "maxnum",
                        "abs" => "fabs",
                        _ => name,
                    };
                    let llvm_name =
                        format!("llvm.{}.{}", intrinsic_name, self.diffsl_real_type.as_str());
                    let intrinsic = Intrinsic::find(&llvm_name).unwrap();
                    let ret_type = self.real_type;

                    let args_types = std::iter::repeat_n(ret_type, arg_len)
                        .map(|f| f.into())
                        .collect::<Vec<BasicTypeEnum>>();
                    // Early return: intrinsics are not cached below.
                    return intrinsic.get_declaration(&self.module, args_types.as_slice());
                }
                // sigmoid(x) = 1 / (1 + exp(-x))
                "sigmoid" => {
                    let arg_len = 1;
                    let ret_type = self.real_type;

                    let args_types = std::iter::repeat_n(ret_type, arg_len)
                        .map(|f| f.into())
                        .collect::<Vec<BasicMetadataTypeEnum>>();

                    let fn_type = ret_type.fn_type(args_types.as_slice(), false);
                    let fn_val = self.module.add_function(name, fn_type, None);

                    for arg in fn_val.get_param_iter() {
                        arg.into_float_value().set_name("x");
                    }

                    // Save the caller's insertion point; restored below.
                    let current_block = self.builder.get_insert_block().unwrap();
                    let basic_block = self.context.append_basic_block(fn_val, "entry");
                    self.builder.position_at_end(basic_block);
                    let x = fn_val.get_nth_param(0)?.into_float_value();
                    let one = self.real_type.const_float(1.0);
                    let negx = self.builder.build_float_neg(x, name).ok()?;
                    let exp = self.get_function("exp").unwrap();
                    let exp_negx = self
                        .builder
                        .build_call(exp, &[BasicMetadataValueEnum::FloatValue(negx)], name)
                        .ok()?;
                    let one_plus_exp_negx = self
                        .builder
                        .build_float_add(
                            exp_negx
                                .try_as_basic_value()
                                .unwrap_basic()
                                .into_float_value(),
                            one,
                            name,
                        )
                        .ok()?;
                    let sigmoid = self
                        .builder
                        .build_float_div(one, one_plus_exp_negx, name)
                        .ok()?;
                    self.builder.build_return(Some(&sigmoid)).ok();
                    self.builder.position_at_end(current_block);
                    Some(fn_val)
                }
                // arcsinh(x) = ln(x + sqrt(x^2 + 1))
                // arccosh(x) = ln(x + sqrt(x^2 - 1))
                "arcsinh" | "arccosh" => {
                    let arg_len = 1;
                    let ret_type = self.real_type;

                    let args_types = std::iter::repeat_n(ret_type, arg_len)
                        .map(|f| f.into())
                        .collect::<Vec<BasicMetadataTypeEnum>>();

                    let fn_type = ret_type.fn_type(args_types.as_slice(), false);
                    let fn_val = self.module.add_function(name, fn_type, None);

                    for arg in fn_val.get_param_iter() {
                        arg.into_float_value().set_name("x");
                    }

                    let current_block = self.builder.get_insert_block().unwrap();
                    let basic_block = self.context.append_basic_block(fn_val, "entry");
                    self.builder.position_at_end(basic_block);
                    let x = fn_val.get_nth_param(0)?.into_float_value();
                    // +1 for arcsinh, -1 for arccosh (the "one" added
                    // under the square root).
                    let one = match name {
                        "arccosh" => self.real_type.const_float(-1.0),
                        "arcsinh" => self.real_type.const_float(1.0),
                        _ => panic!("unknown function"),
                    };
                    let x_squared = self.builder.build_float_mul(x, x, name).ok()?;
                    let one_plus_x_squared =
                        self.builder.build_float_add(x_squared, one, name).ok()?;
                    let sqrt = self.get_function("sqrt").unwrap();
                    let sqrt_one_plus_x_squared = self
                        .builder
                        .build_call(
                            sqrt,
                            &[BasicMetadataValueEnum::FloatValue(one_plus_x_squared)],
                            name,
                        )
                        .unwrap()
                        .try_as_basic_value()
                        .unwrap_basic()
                        .into_float_value();
                    let x_plus_sqrt_one_plus_x_squared = self
                        .builder
                        .build_float_add(x, sqrt_one_plus_x_squared, name)
                        .ok()?;
                    let ln = self.get_function("log").unwrap();
                    let result = self
                        .builder
                        .build_call(
                            ln,
                            &[BasicMetadataValueEnum::FloatValue(
                                x_plus_sqrt_one_plus_x_squared,
                            )],
                            name,
                        )
                        .unwrap()
                        .try_as_basic_value()
                        .unwrap_basic()
                        .into_float_value();
                    self.builder.build_return(Some(&result)).ok();
                    self.builder.position_at_end(current_block);
                    Some(fn_val)
                }
                // heaviside(x) = 1 if x >= 0 else 0
                "heaviside" => {
                    let arg_len = 1;
                    let ret_type = self.real_type;

                    let args_types = std::iter::repeat_n(ret_type, arg_len)
                        .map(|f| f.into())
                        .collect::<Vec<BasicMetadataTypeEnum>>();

                    let fn_type = ret_type.fn_type(args_types.as_slice(), false);
                    let fn_val = self.module.add_function(name, fn_type, None);

                    for arg in fn_val.get_param_iter() {
                        arg.into_float_value().set_name("x");
                    }

                    let current_block = self.builder.get_insert_block().unwrap();
                    let basic_block = self.context.append_basic_block(fn_val, "entry");
                    self.builder.position_at_end(basic_block);
                    let x = fn_val.get_nth_param(0)?.into_float_value();
                    let zero = self.real_type.const_float(0.0);
                    let one = self.real_type.const_float(1.0);
                    let result = self
                        .builder
                        .build_select(
                            self.builder
                                .build_float_compare(FloatPredicate::OGE, x, zero, "x >= 0")
                                .unwrap(),
                            one,
                            zero,
                            name,
                        )
                        .ok()?;
                    self.builder.build_return(Some(&result)).ok();
                    self.builder.position_at_end(current_block);
                    Some(fn_val)
                }
                // Hyperbolics built from exp(x) and exp(-x):
                //   tanh = (e^x - e^-x)/(e^x + e^-x),
                //   sinh = (e^x - e^-x)/2, cosh = (e^x + e^-x)/2.
                "tanh" | "sinh" | "cosh" => {
                    let arg_len = 1;
                    let ret_type = self.real_type;

                    let args_types = std::iter::repeat_n(ret_type, arg_len)
                        .map(|f| f.into())
                        .collect::<Vec<BasicMetadataTypeEnum>>();

                    let fn_type = ret_type.fn_type(args_types.as_slice(), false);
                    let fn_val = self.module.add_function(name, fn_type, None);

                    for arg in fn_val.get_param_iter() {
                        arg.into_float_value().set_name("x");
                    }

                    let current_block = self.builder.get_insert_block().unwrap();
                    let basic_block = self.context.append_basic_block(fn_val, "entry");
                    self.builder.position_at_end(basic_block);
                    let x = fn_val.get_nth_param(0)?.into_float_value();
                    let negx = self.builder.build_float_neg(x, name).ok()?;
                    let exp = self.get_function("exp").unwrap();
                    let exp_negx = self
                        .builder
                        .build_call(exp, &[BasicMetadataValueEnum::FloatValue(negx)], name)
                        .ok()?;
                    let expx = self
                        .builder
                        .build_call(exp, &[BasicMetadataValueEnum::FloatValue(x)], name)
                        .ok()?;
                    let expx_minus_exp_negx = self
                        .builder
                        .build_float_sub(
                            expx.try_as_basic_value().unwrap_basic().into_float_value(),
                            exp_negx
                                .try_as_basic_value()
                                .unwrap_basic()
                                .into_float_value(),
                            name,
                        )
                        .ok()?;
                    let expx_plus_exp_negx = self
                        .builder
                        .build_float_add(
                            expx.try_as_basic_value().unwrap_basic().into_float_value(),
                            exp_negx
                                .try_as_basic_value()
                                .unwrap_basic()
                                .into_float_value(),
                            name,
                        )
                        .ok()?;
                    let result = match name {
                        "tanh" => self
                            .builder
                            .build_float_div(expx_minus_exp_negx, expx_plus_exp_negx, name)
                            .ok()?,
                        "sinh" => self
                            .builder
                            .build_float_div(
                                expx_minus_exp_negx,
                                self.real_type.const_float(2.0),
                                name,
                            )
                            .ok()?,
                        "cosh" => self
                            .builder
                            .build_float_div(
                                expx_plus_exp_negx,
                                self.real_type.const_float(2.0),
                                name,
                            )
                            .ok()?,
                        _ => panic!("unknown function"),
                    };
                    self.builder.build_return(Some(&result)).ok();
                    self.builder.position_at_end(current_block);
                    Some(fn_val)
                }
                _ => None,
            }?;
            // Cache the synthesised function for subsequent lookups.
            self.functions.insert(name.to_owned(), function);
            Some(function)
        }
    }
}
/// Current function under construction; panics if codegen has not
/// started a function yet.
#[inline]
fn fn_value(&self) -> FunctionValue<'ctx> {
    self.fn_value_opt.unwrap()
}
1564
/// Destination pointer of the tensor currently being compiled; panics
/// if no tensor compilation is in progress.
#[inline]
fn tensor_ptr(&self) -> PointerValue<'ctx> {
    self.tensor_ptr_opt.unwrap()
}
1569
1570 fn create_entry_block_builder(&self) -> Builder<'ctx> {
1572 let builder = self.context.create_builder();
1573 let entry = self.fn_value().get_first_basic_block().unwrap();
1574 match entry.get_first_instruction() {
1575 Some(first_instr) => builder.position_before(&first_instr),
1576 None => builder.position_at_end(entry),
1577 }
1578 builder
1579 }
1580
/// Compile a rank-0 (scalar) tensor into `res_ptr_opt` (or a fresh
/// entry-block alloca).
///
/// In threaded mode only thread 0 computes and stores the value; all
/// other threads branch straight to the exit block. The store itself is
/// compiled first into a detached `next` block, then the thread-id
/// guard is back-filled into the original block.
fn jit_compile_scalar(
    &mut self,
    a: &Tensor,
    res_ptr_opt: Option<PointerValue<'ctx>>,
) -> Result<PointerValue<'ctx>> {
    let res_type = self.real_type;
    let res_ptr = match res_ptr_opt {
        Some(ptr) => ptr,
        None => self
            .create_entry_block_builder()
            .build_alloca(res_type, a.name())?,
    };
    let name = a.name();
    // A scalar tensor has exactly one element block.
    let elmt = a.elmts().first().unwrap();

    let curr_block = self.builder.get_insert_block().unwrap();
    let mut next_block_opt = None;
    if self.threaded {
        // Compile the body into a detached block; the guard that jumps
        // here is emitted afterwards.
        let next_block = self.context.append_basic_block(self.fn_value(), "next");
        self.builder.position_at_end(next_block);
        next_block_opt = Some(next_block);
    }

    let zero = self.int_type.const_zero();
    let float_value = self.jit_compile_expr(name, elmt.expr(), &[], elmt, zero)?;
    self.builder.build_store(res_ptr, float_value)?;

    if self.threaded {
        let exit_block = self.context.append_basic_block(self.fn_value(), "exit");
        self.builder.build_unconditional_branch(exit_block)?;
        // Back-fill the thread-0 guard at the original position.
        self.builder.position_at_end(curr_block);

        let thread_id = self.get_param("thread_id");
        let thread_id = self
            .builder
            .build_load(self.int_type, *thread_id, "thread_id")
            .unwrap()
            .into_int_value();
        let is_first_thread = self.builder.build_int_compare(
            IntPredicate::EQ,
            thread_id,
            self.int_type.const_zero(),
            "is_first_thread",
        )?;
        self.builder.build_conditional_branch(
            is_first_thread,
            next_block_opt.unwrap(),
            exit_block,
        )?;

        self.builder.position_at_end(exit_block);
    }

    Ok(res_ptr)
}
1638
1639 fn jit_compile_tensor(
1640 &mut self,
1641 a: &Tensor,
1642 res_ptr_opt: Option<PointerValue<'ctx>>,
1643 ) -> Result<PointerValue<'ctx>> {
1644 if a.rank() == 0 {
1646 return self.jit_compile_scalar(a, res_ptr_opt);
1647 }
1648
1649 let res_type = self.real_type;
1650 let res_ptr = match res_ptr_opt {
1651 Some(ptr) => ptr,
1652 None => self
1653 .create_entry_block_builder()
1654 .build_alloca(res_type, a.name())?,
1655 };
1656
1657 self.tensor_ptr_opt = Some(res_ptr);
1659
1660 for (i, blk) in a.elmts().iter().enumerate() {
1661 let default = format!("{}-{}", a.name(), i);
1662 let name = blk.name().unwrap_or(default.as_str());
1663 self.jit_compile_block(name, a, blk)?;
1664 }
1665 Ok(res_ptr)
1666 }
1667
1668 fn jit_compile_block(&mut self, name: &str, tensor: &Tensor, elmt: &TensorBlock) -> Result<()> {
1669 let translation = Translation::new(
1670 elmt.expr_layout(),
1671 elmt.layout(),
1672 elmt.start(),
1673 tensor.layout_ptr(),
1674 );
1675
1676 if elmt.expr_layout().is_dense() {
1677 self.jit_compile_dense_block(name, elmt, &translation)
1678 } else if elmt.expr_layout().is_diagonal() {
1679 self.jit_compile_diagonal_block(name, elmt, &translation)
1680 } else if elmt.expr_layout().is_sparse() {
1681 match translation.source {
1682 TranslationFrom::SparseContraction { .. } => {
1683 self.jit_compile_sparse_contraction_block(name, elmt, &translation)
1684 }
1685 _ => self.jit_compile_sparse_block(name, elmt, &translation),
1686 }
1687 } else {
1688 Err(anyhow!(
1689 "unsupported block layout: {:?}",
1690 elmt.expr_layout()
1691 ))
1692 }
1693 }
1694
/// Compile a dense element block as a nest of counted loops, one per
/// rank, each implemented with a phi node. Supports an optional dense
/// contraction (the innermost `contract_by` axes are summed into
/// `contract_sum` and stored once per outer iteration) and, when
/// threaded, splits the outermost axis across threads via
/// `jit_threading_limits` / `jit_end_threading`.
fn jit_compile_dense_block(
    &mut self,
    name: &str,
    elmt: &TensorBlock,
    translation: &Translation,
) -> Result<()> {
    let int_type = self.int_type;

    let mut preblock = self.builder.get_insert_block().unwrap();
    let expr_rank = elmt.expr_layout().rank();
    let expr_shape = elmt
        .expr_layout()
        .shape()
        .mapv(|n| int_type.const_int(n.try_into().unwrap(), false));
    let one = int_type.const_int(1, false);

    // Row-major strides of the expression layout.
    let mut expr_strides = vec![1; expr_rank];
    if expr_rank > 0 {
        for i in (0..expr_rank - 1).rev() {
            expr_strides[i] = expr_strides[i + 1] * elmt.expr_layout().shape()[i + 1];
        }
    }
    let expr_strides = expr_strides
        .iter()
        .map(|&s| int_type.const_int(s.try_into().unwrap(), false))
        .collect::<Vec<IntValue>>();

    // One phi (loop counter) and one header block per rank.
    let mut indices = Vec::new();
    let mut blocks = Vec::new();

    // For dense contractions: an accumulator slot, the number of
    // trailing axes being summed over, and strides of the kept axes.
    let (contract_sum, contract_by, contract_strides) =
        if let TranslationFrom::DenseContraction {
            contract_by,
            contract_len: _,
        } = translation.source
        {
            let contract_rank = expr_rank - contract_by;
            let mut contract_strides = vec![1; contract_rank];
            for i in (0..contract_rank - 1).rev() {
                contract_strides[i] =
                    contract_strides[i + 1] * elmt.expr_layout().shape()[i + 1];
            }
            let contract_strides = contract_strides
                .iter()
                .map(|&s| int_type.const_int(s.try_into().unwrap(), false))
                .collect::<Vec<IntValue>>();
            (
                Some(self.builder.build_alloca(self.real_type, "contract_sum")?),
                contract_by,
                Some(contract_strides),
            )
        } else {
            (None, 0, None)
        };

    // Threaded: the outermost axis is partitioned into [start, end).
    let (thread_start, thread_end, test_done, next) = if self.threaded {
        let (start, end, test_done, next) =
            self.jit_threading_limits(*expr_shape.get(0).unwrap_or(&one))?;
        preblock = next;
        (Some(start), Some(end), Some(test_done), Some(next))
    } else {
        (None, None, None, None)
    };

    // Emit the loop headers, outermost first. Each phi starts at 0
    // (or the thread's start for the outermost threaded axis); the
    // back-edge incoming values are added in the closing loop below.
    for i in 0..expr_rank {
        let block = self.context.append_basic_block(self.fn_value(), name);
        self.builder.build_unconditional_branch(block)?;
        self.builder.position_at_end(block);

        let start_index = if i == 0 && self.threaded {
            thread_start.unwrap()
        } else {
            self.int_type.const_zero()
        };

        let curr_index = self.builder.build_phi(int_type, format!["i{i}"].as_str())?;
        curr_index.add_incoming(&[(&start_index, preblock)]);

        // Reset the accumulator when entering the contraction axes.
        if i == expr_rank - contract_by - 1 {
            if let Some(contract_sum) = contract_sum {
                self.builder
                    .build_store(contract_sum, self.real_type.const_zero())?;
            }
        }

        indices.push(curr_index);
        blocks.push(block);
        preblock = block;
    }

    let indices_int: Vec<IntValue> = indices
        .iter()
        .map(|i| i.as_basic_value().into_int_value())
        .collect();

    // Flattened row-major index fed to jit_compile_expr.
    // NOTE(review): the stride constants here use self.context.i32_type()
    // while the loop indices use self.int_type — confirm these are the
    // same width, otherwise build_int_mul would mix operand types.
    let mut expr_index = *indices_int.last().unwrap_or(&int_type.const_zero());
    let mut stride = 1u64;
    if !indices.is_empty() {
        for i in (0..indices.len() - 1).rev() {
            let iname_i = indices_int[i];
            let shapei: u64 = elmt.expr_layout().shape()[i + 1].try_into().unwrap();
            stride *= shapei;
            let stride_intval = self.context.i32_type().const_int(stride, false);
            let stride_mul_i = self.builder.build_int_mul(stride_intval, iname_i, name)?;
            expr_index = self.builder.build_int_add(expr_index, stride_mul_i, name)?;
        }
    }

    let float_value =
        self.jit_compile_expr(name, elmt.expr(), indices_int.as_slice(), elmt, expr_index)?;

    if let Some(contract_sum) = contract_sum {
        // Contraction: accumulate into the scratch slot; the store to
        // the destination happens when the contracted axes close below.
        let contract_sum_value = self
            .build_load(self.real_type, contract_sum, "contract_sum")?
            .into_float_value();
        let new_contract_sum_value = self.builder.build_float_add(
            contract_sum_value,
            float_value,
            "new_contract_sum",
        )?;
        self.builder
            .build_store(contract_sum, new_contract_sum_value)?;
    } else {
        // No contraction: recompute the flat index from the precomputed
        // strides (shadows the expr_index used for the expression) and
        // store/broadcast the value.
        let expr_index = indices_int.iter().zip(expr_strides.iter()).fold(
            self.int_type.const_zero(),
            |acc, (i, s)| {
                let tmp = self.builder.build_int_mul(*i, *s, "expr_index").unwrap();
                self.builder.build_int_add(acc, tmp, "acc").unwrap()
            },
        );
        self.jit_compile_broadcast_and_store(name, elmt, float_value, expr_index, translation)?;
    }

    let mut postblock = self.builder.get_insert_block().unwrap();

    // Close the loops innermost-first: bump each counter, wire the phi
    // back-edge, and branch back to the header while the counter is
    // below its bound (the thread's end for the outermost threaded axis).
    for i in (0..expr_rank).rev() {
        let next_index = self.builder.build_int_add(indices_int[i], one, name)?;
        indices[i].add_incoming(&[(&next_index, postblock)]);
        // Leaving the contraction axes: store the accumulated sum at
        // the position given by the kept (non-contracted) indices.
        if i == expr_rank - contract_by - 1 {
            if let Some(contract_sum) = contract_sum {
                let contract_sum_value = self
                    .build_load(self.real_type, contract_sum, "contract_sum")?
                    .into_float_value();
                let contract_strides = contract_strides.as_ref().unwrap();
                let elmt_index = indices_int
                    .iter()
                    .take(contract_strides.len())
                    .zip(contract_strides.iter())
                    .fold(self.int_type.const_zero(), |acc, (i, s)| {
                        let tmp = self.builder.build_int_mul(*i, *s, "elmt_index").unwrap();
                        self.builder.build_int_add(acc, tmp, "acc").unwrap()
                    });
                self.jit_compile_store(
                    name,
                    elmt,
                    elmt_index,
                    contract_sum_value,
                    translation,
                )?;
            }
        }

        let end_index = if i == 0 && self.threaded {
            thread_end.unwrap()
        } else {
            expr_shape[i]
        };

        let loop_while =
            self.builder
                .build_int_compare(IntPredicate::ULT, next_index, end_index, name)?;
        let block = self.context.append_basic_block(self.fn_value(), name);
        self.builder
            .build_conditional_branch(loop_while, blocks[i], block)?;
        self.builder.position_at_end(block);
        postblock = block;
    }

    if self.threaded {
        self.jit_end_threading(
            thread_start.unwrap(),
            thread_end.unwrap(),
            test_done.unwrap(),
            next.unwrap(),
        )?;
    }
    Ok(())
}
1891
/// Compile a sparse contraction block: an outer loop over the nnz
/// entries of the *target* layout, each iteration summing a contiguous
/// run of expression entries whose [start, end) bounds are read as
/// consecutive pairs from the translation section of the global
/// "indices" array. When threaded, the outer loop is partitioned across
/// threads.
fn jit_compile_sparse_contraction_block(
    &mut self,
    name: &str,
    elmt: &TensorBlock,
    translation: &Translation,
) -> Result<()> {
    // Caller contract: only valid for sparse contractions.
    match translation.source {
        TranslationFrom::SparseContraction { .. } => {}
        _ => {
            panic!("expected sparse contraction")
        }
    }
    let int_type = self.int_type;

    // Offset of this block's [start, end) pairs within "indices".
    let translation_index = self
        .layout
        .get_translation_index(elmt.expr_layout(), elmt.layout())
        .unwrap();
    let translation_index = translation_index + translation.get_from_index_in_data_layout();

    let final_contract_index =
        int_type.const_int(elmt.layout().nnz().try_into().unwrap(), false);
    let (thread_start, thread_end, test_done, next) = if self.threaded {
        let (start, end, test_done, next) = self.jit_threading_limits(final_contract_index)?;
        (Some(start), Some(end), Some(test_done), Some(next))
    } else {
        (None, None, None, None)
    };

    let preblock = self.builder.get_insert_block().unwrap();
    let contract_sum_ptr = self.builder.build_alloca(self.real_type, "contract_sum")?;

    // Outer loop: one iteration per stored (non-zero) target entry.
    let block = self.context.append_basic_block(self.fn_value(), name);
    self.builder.build_unconditional_branch(block)?;
    self.builder.position_at_end(block);

    let contract_index = self.builder.build_phi(int_type, "i")?;
    let contract_start = if self.threaded {
        thread_start.unwrap()
    } else {
        int_type.const_zero()
    };
    contract_index.add_incoming(&[(&contract_start, preblock)]);

    // The pair for entry i lives at translation_index + 2*i
    // (start) and translation_index + 2*i + 1 (end).
    let start_index = self.builder.build_int_add(
        int_type.const_int(translation_index.try_into().unwrap(), false),
        self.builder.build_int_mul(
            int_type.const_int(2, false),
            contract_index.as_basic_value().into_int_value(),
            name,
        )?,
        name,
    )?;
    let end_index =
        self.builder
            .build_int_add(start_index, int_type.const_int(1, false), name)?;
    let start_ptr = self.build_gep(
        self.int_type,
        *self.get_param("indices"),
        &[start_index],
        "start_index_ptr",
    )?;
    let start_contract = self
        .build_load(self.int_type, start_ptr, "start")?
        .into_int_value();
    let end_ptr = self.build_gep(
        self.int_type,
        *self.get_param("indices"),
        &[end_index],
        "end_index_ptr",
    )?;
    let end_contract = self
        .build_load(self.int_type, end_ptr, "end")?
        .into_int_value();

    // Reset the accumulator for this target entry.
    self.builder
        .build_store(contract_sum_ptr, self.real_type.const_float(0.0))?;

    // Inner loop over expression entries [start_contract, end_contract).
    let start_contract_block = self
        .context
        .append_basic_block(self.fn_value(), format!("{name}_contract").as_str());
    self.builder
        .build_unconditional_branch(start_contract_block)?;
    self.builder.position_at_end(start_contract_block);

    let expr_index_phi = self.builder.build_phi(int_type, "j")?;
    expr_index_phi.add_incoming(&[(&start_contract, block)]);

    let expr_index = expr_index_phi.as_basic_value().into_int_value();
    let indices_int = self.expr_indices_from_elmt_index(expr_index, elmt, name)?;

    // Evaluate the expression at this entry and accumulate.
    let float_value =
        self.jit_compile_expr(name, elmt.expr(), indices_int.as_slice(), elmt, expr_index)?;
    let contract_sum_value = self
        .build_load(self.real_type, contract_sum_ptr, "contract_sum")?
        .into_float_value();
    let new_contract_sum_value =
        self.builder
            .build_float_add(contract_sum_value, float_value, "new_contract_sum")?;
    self.builder
        .build_store(contract_sum_ptr, new_contract_sum_value)?;

    // jit_compile_expr may have moved the insertion point; the back-edge
    // must come from the block we actually ended in.
    let end_contract_block = self.builder.get_insert_block().unwrap();

    let next_elmt_index =
        self.builder
            .build_int_add(expr_index, int_type.const_int(1, false), name)?;
    expr_index_phi.add_incoming(&[(&next_elmt_index, end_contract_block)]);

    let loop_while = self.builder.build_int_compare(
        IntPredicate::ULT,
        next_elmt_index,
        end_contract,
        name,
    )?;
    let post_contract_block = self.context.append_basic_block(self.fn_value(), name);
    self.builder.build_conditional_branch(
        loop_while,
        start_contract_block,
        post_contract_block,
    )?;
    self.builder.position_at_end(post_contract_block);

    // Store the finished sum at this target entry.
    self.jit_compile_store(
        name,
        elmt,
        contract_index.as_basic_value().into_int_value(),
        new_contract_sum_value,
        translation,
    )?;

    // Close the outer loop.
    let next_contract_index = self.builder.build_int_add(
        contract_index.as_basic_value().into_int_value(),
        int_type.const_int(1, false),
        name,
    )?;
    contract_index.add_incoming(&[(&next_contract_index, post_contract_block)]);

    let loop_while = self.builder.build_int_compare(
        IntPredicate::ULT,
        next_contract_index,
        thread_end.unwrap_or(final_contract_index),
        name,
    )?;
    let post_block = self.context.append_basic_block(self.fn_value(), name);
    self.builder
        .build_conditional_branch(loop_while, block, post_block)?;
    self.builder.position_at_end(post_block);

    if self.threaded {
        self.jit_end_threading(
            thread_start.unwrap(),
            thread_end.unwrap(),
            test_done.unwrap(),
            next.unwrap(),
        )?;
    }

    Ok(())
}
2061
/// Recover the per-rank indices of sparse entry `elmt_index` by loading
/// them from the global "indices" array, where this layout's entries
/// are stored as consecutive rank-tuples starting at `layout_index`:
/// index i of the entry lives at `elmt_index * rank + layout_index + i`.
fn expr_indices_from_elmt_index(
    &mut self,
    elmt_index: IntValue<'ctx>,
    elmt: &TensorBlock,
    name: &str,
) -> Result<Vec<IntValue<'ctx>>, anyhow::Error> {
    let layout_index = self.layout.get_layout_index(elmt.expr_layout()).unwrap();
    let int_type = self.int_type;
    let elmt_index_mult_rank = self.builder.build_int_mul(
        elmt_index,
        int_type.const_int(elmt.expr_layout().rank().try_into().unwrap(), false),
        name,
    )?;
    (0..elmt.expr_layout().rank())
        .map(|i| {
            let layout_index_plus_offset =
                int_type.const_int((layout_index + i).try_into().unwrap(), false);
            let curr_index = self.builder.build_int_add(
                elmt_index_mult_rank,
                layout_index_plus_offset,
                name,
            )?;
            let ptr = Self::get_ptr_to_index(
                &self.builder,
                self.int_type,
                self.get_param("indices"),
                curr_index,
                name,
            );
            Ok(self.build_load(self.int_type, ptr, name)?.into_int_value())
        })
        .collect::<Result<Vec<_>, anyhow::Error>>()
}
2096
/// Compile a (non-contracting) sparse block as a single phi-based loop
/// over the nnz stored entries of the expression layout, evaluating the
/// expression at each entry's recovered indices and storing via
/// broadcast. When threaded, the entry range is partitioned across
/// threads.
fn jit_compile_sparse_block(
    &mut self,
    name: &str,
    elmt: &TensorBlock,
    translation: &Translation,
) -> Result<()> {
    let int_type = self.int_type;

    let start_index = int_type.const_int(0, false);
    let end_index = int_type.const_int(elmt.expr_layout().nnz().try_into().unwrap(), false);

    let (thread_start, thread_end, test_done, next) = if self.threaded {
        let (start, end, test_done, next) = self.jit_threading_limits(end_index)?;
        (Some(start), Some(end), Some(test_done), Some(next))
    } else {
        (None, None, None, None)
    };

    let preblock = self.builder.get_insert_block().unwrap();
    let loop_block = self.context.append_basic_block(self.fn_value(), name);
    self.builder.build_unconditional_branch(loop_block)?;
    self.builder.position_at_end(loop_block);

    // Loop counter over stored entries, starting at the thread's slice
    // start (or zero when single-threaded).
    let curr_index = self.builder.build_phi(int_type, "i")?;
    curr_index.add_incoming(&[(&thread_start.unwrap_or(start_index), preblock)]);

    let elmt_index = curr_index.as_basic_value().into_int_value();
    // Recover the per-rank indices of this entry from "indices".
    let indices_int = self.expr_indices_from_elmt_index(elmt_index, elmt, name)?;

    let float_value =
        self.jit_compile_expr(name, elmt.expr(), indices_int.as_slice(), elmt, elmt_index)?;

    self.jit_compile_broadcast_and_store(name, elmt, float_value, elmt_index, translation)?;

    // The expression/store may have moved the insertion point; the phi
    // back-edge must come from the block we actually ended in.
    let end_loop_block = self.builder.get_insert_block().unwrap();

    let one = int_type.const_int(1, false);
    let next_index = self.builder.build_int_add(elmt_index, one, name)?;
    curr_index.add_incoming(&[(&next_index, end_loop_block)]);

    let loop_while = self.builder.build_int_compare(
        IntPredicate::ULT,
        next_index,
        thread_end.unwrap_or(end_index),
        name,
    )?;
    let post_block = self.context.append_basic_block(self.fn_value(), name);
    self.builder
        .build_conditional_branch(loop_while, loop_block, post_block)?;
    self.builder.position_at_end(post_block);

    if self.threaded {
        self.jit_end_threading(
            thread_start.unwrap(),
            thread_end.unwrap(),
            test_done.unwrap(),
            next.unwrap(),
        )?;
    }

    Ok(())
}
2166
    /// Generate IR that evaluates a tensor block with a diagonal layout.
    ///
    /// Structurally identical to `jit_compile_sparse_block`, except that for a
    /// diagonal layout every dimension shares the same index, so the index
    /// vector passed to the expression is just the element index repeated
    /// `rank()` times — no lookup in the "indices" array is needed.
    ///
    /// When `self.threaded` is set, the `[0, nnz)` range is split across
    /// threads via `jit_threading_limits` / `jit_end_threading`.
    fn jit_compile_diagonal_block(
        &mut self,
        name: &str,
        elmt: &TensorBlock,
        translation: &Translation,
    ) -> Result<()> {
        let int_type = self.int_type;

        let start_index = int_type.const_int(0, false);
        let end_index = int_type.const_int(elmt.expr_layout().nnz().try_into().unwrap(), false);

        // per-thread sub-range of [0, nnz), or full range when single-threaded
        let (thread_start, thread_end, test_done, next) = if self.threaded {
            let (start, end, test_done, next) = self.jit_threading_limits(end_index)?;
            (Some(start), Some(end), Some(test_done), Some(next))
        } else {
            (None, None, None, None)
        };

        // loop header: phi takes the start index from the preceding block
        let preblock = self.builder.get_insert_block().unwrap();
        let start_loop_block = self.context.append_basic_block(self.fn_value(), name);
        self.builder.build_unconditional_branch(start_loop_block)?;
        self.builder.position_at_end(start_loop_block);

        let curr_index = self.builder.build_phi(int_type, "i")?;
        curr_index.add_incoming(&[(&thread_start.unwrap_or(start_index), preblock)]);

        let elmt_index = curr_index.as_basic_value().into_int_value();
        // diagonal: every dimension uses the same element index
        let indices_int: Vec<IntValue> =
            (0..elmt.expr_layout().rank()).map(|_| elmt_index).collect();

        let float_value =
            self.jit_compile_expr(name, elmt.expr(), indices_int.as_slice(), elmt, elmt_index)?;

        self.jit_compile_broadcast_and_store(name, elmt, float_value, elmt_index, translation)?;

        // the back-edge comes from the current insertion block, since the
        // store may have emitted additional blocks
        let end_loop_block = self.builder.get_insert_block().unwrap();

        let one = int_type.const_int(1, false);
        let next_index = self.builder.build_int_add(elmt_index, one, name)?;
        curr_index.add_incoming(&[(&next_index, end_loop_block)]);

        // continue while next_index < end of this thread's range
        let loop_while = self.builder.build_int_compare(
            IntPredicate::ULT,
            next_index,
            thread_end.unwrap_or(end_index),
            name,
        )?;
        let post_block = self.context.append_basic_block(self.fn_value(), name);
        self.builder
            .build_conditional_branch(loop_while, start_loop_block, post_block)?;
        self.builder.position_at_end(post_block);

        if self.threaded {
            self.jit_end_threading(
                thread_start.unwrap(),
                thread_end.unwrap(),
                test_done.unwrap(),
                next.unwrap(),
            )?;
        }

        Ok(())
    }
2237
    /// Store an evaluated element value into the target tensor, expanding it
    /// over a broadcast dimension if the translation requires it.
    ///
    /// * `TranslationFrom::Broadcast` — emits an inner loop of
    ///   `broadcast_len` iterations, storing `float_value` at
    ///   `expr_index * broadcast_len + i` for each iteration `i`.
    /// * `TranslationFrom::ElementWise` / `DiagonalContraction` — a single
    ///   store at `expr_index`.
    /// * any other source kind is rejected with an error.
    fn jit_compile_broadcast_and_store(
        &mut self,
        name: &str,
        elmt: &TensorBlock,
        float_value: FloatValue<'ctx>,
        expr_index: IntValue<'ctx>,
        translation: &Translation,
    ) -> Result<()> {
        let int_type = self.int_type;
        let one = int_type.const_int(1, false);
        let zero = int_type.const_int(0, false);
        let pre_block = self.builder.get_insert_block().unwrap();
        match translation.source {
            TranslationFrom::Broadcast {
                broadcast_by: _,
                broadcast_len,
            } => {
                let bcast_start_index = zero;
                let bcast_end_index = int_type.const_int(broadcast_len.try_into().unwrap(), false);

                // inner broadcast loop: phi over the broadcast index
                let bcast_block = self.context.append_basic_block(self.fn_value(), name);
                self.builder.build_unconditional_branch(bcast_block)?;
                self.builder.position_at_end(bcast_block);
                let bcast_index = self.builder.build_phi(int_type, "broadcast_index")?;
                bcast_index.add_incoming(&[(&bcast_start_index, pre_block)]);

                // destination = expr_index * broadcast_len + bcast_index
                let store_index = self.builder.build_int_add(
                    self.builder
                        .build_int_mul(expr_index, bcast_end_index, "store_index")?,
                    bcast_index.as_basic_value().into_int_value(),
                    "bcast_store_index",
                )?;
                self.jit_compile_store(name, elmt, store_index, float_value, translation)?;

                let bcast_next_index = self.builder.build_int_add(
                    bcast_index.as_basic_value().into_int_value(),
                    one,
                    name,
                )?;
                // NOTE(review): the back-edge is attributed to `bcast_block`,
                // which assumes `jit_compile_store` emits no new basic blocks
                // (it is straight-line code today) — confirm if that changes
                bcast_index.add_incoming(&[(&bcast_next_index, bcast_block)]);

                let bcast_cond = self.builder.build_int_compare(
                    IntPredicate::ULT,
                    bcast_next_index,
                    bcast_end_index,
                    "broadcast_cond",
                )?;
                let post_bcast_block = self.context.append_basic_block(self.fn_value(), name);
                self.builder
                    .build_conditional_branch(bcast_cond, bcast_block, post_bcast_block)?;
                self.builder.position_at_end(post_bcast_block);
                Ok(())
            }
            TranslationFrom::ElementWise | TranslationFrom::DiagonalContraction { .. } => {
                // no broadcast needed: store directly at the expression index
                self.jit_compile_store(name, elmt, expr_index, float_value, translation)?;
                Ok(())
            }
            _ => Err(anyhow!("Invalid translation")),
        }
    }
2302
    /// Emit the store of one computed element into the current result tensor.
    ///
    /// The destination index is derived from `store_index` according to the
    /// translation target:
    /// * `Contiguous` — offset `store_index` by the constant target start.
    /// * `Sparse` — map `store_index` through a translation table held in the
    ///   global "indices" array (table base = layout translation index plus
    ///   the translation's offset into the data layout).
    ///
    /// The value is written through `self.tensor_ptr()` at the final index.
    fn jit_compile_store(
        &mut self,
        name: &str,
        elmt: &TensorBlock,
        store_index: IntValue<'ctx>,
        float_value: FloatValue<'ctx>,
        translation: &Translation,
    ) -> Result<()> {
        let int_type = self.int_type;
        let res_index = match &translation.target {
            TranslationTo::Contiguous { start, end: _ } => {
                // dense target: destination is just start + store_index
                let start_const = int_type.const_int((*start).try_into().unwrap(), false);
                self.builder.build_int_add(start_const, store_index, name)?
            }
            TranslationTo::Sparse { indices: _ } => {
                // sparse target: look up the mapped destination index in the
                // translation table embedded in the "indices" array
                let translate_index = self
                    .layout
                    .get_translation_index(elmt.expr_layout(), elmt.layout())
                    .unwrap();
                let translate_store_index =
                    translate_index + translation.get_to_index_in_data_layout();
                let translate_store_index =
                    int_type.const_int(translate_store_index.try_into().unwrap(), false);
                let elmt_index_strided = store_index;
                let curr_index =
                    self.builder
                        .build_int_add(elmt_index_strided, translate_store_index, name)?;
                let ptr = Self::get_ptr_to_index(
                    &self.builder,
                    self.int_type,
                    self.get_param("indices"),
                    curr_index,
                    name,
                );
                self.build_load(self.int_type, ptr, name)?.into_int_value()
            }
        };

        // write the value into the result tensor at the resolved index
        let resi_ptr = Self::get_ptr_to_index(
            &self.builder,
            self.real_type,
            &self.tensor_ptr(),
            res_index,
            name,
        );
        self.builder.build_store(resi_ptr, float_value)?;
        Ok(())
    }
2352
    /// Recursively compile an AST expression to a floating-point LLVM value.
    ///
    /// `index` holds the per-dimension index values of the element currently
    /// being computed; `expr_index` is its flat element index (used for
    /// sparse-layout lookups). Handles:
    /// * `Binop` / `Monop` — float arithmetic on recursively compiled operands.
    /// * `Call` / `CallArg` — call into a registered function with compiled args.
    /// * `Number` — a float constant.
    /// * `Name` — load a value from a named tensor parameter; index
    ///   computation depends on the tensor's layout (dense row-major strides,
    ///   or sparse/diagonal with an optional translation-table lookup that
    ///   yields 0.0 for missing entries).
    /// * `NamedGradient` — load a scalar gradient parameter.
    /// * `Index` / `Slice` / `Integer` are not yet implemented.
    fn jit_compile_expr(
        &mut self,
        name: &str,
        expr: &Ast,
        index: &[IntValue<'ctx>],
        elmt: &TensorBlock,
        expr_index: IntValue<'ctx>,
    ) -> Result<FloatValue<'ctx>> {
        // prefer the tensor block's own name for generated value names
        let name = elmt.name().unwrap_or(name);
        match &expr.kind {
            AstKind::Binop(binop) => {
                let lhs =
                    self.jit_compile_expr(name, binop.left.as_ref(), index, elmt, expr_index)?;
                let rhs =
                    self.jit_compile_expr(name, binop.right.as_ref(), index, elmt, expr_index)?;
                match binop.op {
                    '*' => Ok(self.builder.build_float_mul(lhs, rhs, name)?),
                    '/' => Ok(self.builder.build_float_div(lhs, rhs, name)?),
                    '-' => Ok(self.builder.build_float_sub(lhs, rhs, name)?),
                    '+' => Ok(self.builder.build_float_add(lhs, rhs, name)?),
                    unknown => Err(anyhow!("unknown binop op '{}'", unknown)),
                }
            }
            AstKind::Monop(monop) => {
                let child =
                    self.jit_compile_expr(name, monop.child.as_ref(), index, elmt, expr_index)?;
                match monop.op {
                    '-' => Ok(self.builder.build_float_neg(child, name)?),
                    unknown => Err(anyhow!("unknown monop op '{}'", unknown)),
                }
            }
            AstKind::Call(call) => match self.get_function(call.fn_name) {
                Some(function) => {
                    // compile each argument, then emit the call and take the
                    // float return value
                    let mut args: Vec<BasicMetadataValueEnum> = Vec::new();
                    for arg in call.args.iter() {
                        let arg_val =
                            self.jit_compile_expr(name, arg.as_ref(), index, elmt, expr_index)?;
                        args.push(BasicMetadataValueEnum::FloatValue(arg_val));
                    }
                    let ret_value = self
                        .builder
                        .build_call(function, args.as_slice(), name)?
                        .try_as_basic_value()
                        .unwrap_basic()
                        .into_float_value();
                    Ok(ret_value)
                }
                None => Err(anyhow!("unknown function call '{}'", call.fn_name)),
            },
            AstKind::CallArg(arg) => {
                // transparent wrapper: compile the inner expression
                self.jit_compile_expr(name, &arg.expression, index, elmt, expr_index)
            }
            AstKind::Number(value) => Ok(self.real_type.const_float(*value)),
            AstKind::Name(iname) => {
                let ptr = self.get_param(iname.name);
                let layout = self.layout.get_layout(iname.name).unwrap();
                let iname_elmt_index = if layout.is_dense() {
                    // dense tensor: translate the element's index tuple into
                    // this tensor's index space, then flatten with row-major
                    // strides derived from the tensor's shape
                    // NOTE(review): `no_transform` is computed but never read
                    // after the loop — candidate for removal
                    let mut no_transform = true;
                    let mut iname_index = Vec::new();
                    for (i, c) in iname.indices.iter().enumerate() {
                        // position of this tensor's index character within the
                        // block's index list (len() if absent)
                        let pi = elmt
                            .indices()
                            .iter()
                            .position(|x| x == c)
                            .unwrap_or(elmt.indices().len());
                        if let Some(indice) =
                            iname.indice.as_ref().map(|i| i.kind.as_indice().unwrap())
                        {
                            // explicit slice/offset: add the constant start to
                            // the (possibly missing, then zero) index value
                            let start = indice.first.as_ref().kind.as_integer().unwrap();
                            let start_intval = self
                                .context
                                .i32_type()
                                .const_int(start.try_into().unwrap(), false);
                            let index_pi = if pi >= index.len() {
                                self.context.i32_type().const_int(0, false)
                            } else {
                                index[pi]
                            };
                            let index_pi =
                                self.builder.build_int_add(index_pi, start_intval, name)?;
                            iname_index.push(index_pi);
                        } else {
                            iname_index.push(index[pi]);
                        }
                        no_transform = no_transform && pi == i;
                    }
                    if !iname_index.is_empty() {
                        // flatten: start from the last (fastest-varying) index
                        // and accumulate stride * index for the outer dims
                        let mut iname_elmt_index = *iname_index.last().unwrap();
                        let mut stride = 1u64;
                        for i in (0..iname_index.len() - 1).rev() {
                            let iname_i = iname_index[i];
                            let shapei: u64 = layout.shape()[i + 1].try_into().unwrap();
                            stride *= shapei;
                            let stride_intval = self.context.i32_type().const_int(stride, false);
                            let stride_mul_i =
                                self.builder.build_int_mul(stride_intval, iname_i, name)?;
                            iname_elmt_index =
                                self.builder
                                    .build_int_add(iname_elmt_index, stride_mul_i, name)?;
                        }
                        iname_elmt_index
                    } else {
                        // scalar tensor: single element at index 0
                        let zero = self.context.i32_type().const_int(0, false);
                        zero
                    }
                } else if layout.is_sparse() || layout.is_diagonal() {
                    let expr_layout = elmt.expr_layout();

                    if expr_layout != layout {
                        // layouts differ: try to map expr_index through a
                        // precomputed binary-layout translation table
                        let permutation =
                            DataLayout::permutation(elmt, iname.indices.as_slice(), layout);
                        if let Some(base_binary_layout_index) =
                            self.layout
                                .get_binary_layout_index(layout, expr_layout, permutation)
                        {
                            let binary_layout_index = self.builder.build_int_add(
                                self.int_type
                                    .const_int(base_binary_layout_index.try_into().unwrap(), false),
                                expr_index,
                                name,
                            )?;

                            let indices_ptr = Self::get_ptr_to_index(
                                &self.builder,
                                self.int_type,
                                self.get_param("indices"),
                                binary_layout_index,
                                name,
                            );
                            let mapped_index = self
                                .build_load(self.int_type, indices_ptr, name)?
                                .into_int_value();
                            // a negative mapped index means the entry is not
                            // stored in this sparse tensor — yield 0.0 instead
                            // of loading (emitted as an if/else + phi)
                            let is_less_than_zero = self.builder.build_int_compare(
                                IntPredicate::SLT,
                                mapped_index,
                                self.int_type.const_int(0, true),
                                "sparse_index_check",
                            )?;
                            let is_less_than_zero_block =
                                self.context.append_basic_block(self.fn_value(), "lt_zero");
                            let not_less_than_zero_block = self
                                .context
                                .append_basic_block(self.fn_value(), "not_lt_zero");
                            let merge_block =
                                self.context.append_basic_block(self.fn_value(), "merge");
                            self.builder.build_conditional_branch(
                                is_less_than_zero,
                                is_less_than_zero_block,
                                not_less_than_zero_block,
                            )?;

                            // missing entry: contribute 0.0
                            self.builder.position_at_end(is_less_than_zero_block);
                            let zero_value = self.real_type.const_float(0.);
                            self.builder.build_unconditional_branch(merge_block)?;

                            // stored entry: load the value at the mapped index
                            self.builder.position_at_end(not_less_than_zero_block);
                            let value_ptr = Self::get_ptr_to_index(
                                &self.builder,
                                self.real_type,
                                ptr,
                                mapped_index,
                                name,
                            );
                            let value = self
                                .build_load(self.real_type, value_ptr, name)?
                                .into_float_value();
                            self.builder.build_unconditional_branch(merge_block)?;

                            // merge the two arms with a phi and return early
                            self.builder.position_at_end(merge_block);
                            let if_return_value =
                                self.builder.build_phi(self.real_type, "sparse_value")?;
                            if_return_value.add_incoming(&[(&zero_value, is_less_than_zero_block)]);
                            if_return_value.add_incoming(&[(&value, not_less_than_zero_block)]);

                            let phi_value = if_return_value.as_basic_value().into_float_value();
                            return Ok(phi_value);
                        } else {
                            // no translation table: layouts are assumed to be
                            // compatible, use the flat element index directly
                            expr_index
                        }
                    } else {
                        // identical layouts: flat element index works as-is
                        expr_index
                    }
                } else {
                    panic!("unexpected layout");
                };
                // common exit for dense and untranslated sparse/diagonal paths
                let value_ptr = Self::get_ptr_to_index(
                    &self.builder,
                    self.real_type,
                    ptr,
                    iname_elmt_index,
                    name,
                );
                Ok(self
                    .build_load(self.real_type, value_ptr, name)?
                    .into_float_value())
            }
            AstKind::NamedGradient(name) => {
                // gradient parameters are scalar: load directly from the param
                let name_str = name.to_string();
                let ptr = self.get_param(name_str.as_str());
                Ok(self
                    .build_load(self.real_type, *ptr, name_str.as_str())?
                    .into_float_value())
            }
            AstKind::Index(_) => todo!(),
            AstKind::Slice(_) => todo!(),
            AstKind::Integer(_) => todo!(),
            _ => panic!("unexprected astkind"),
        }
    }
2580
2581 fn clear(&mut self) {
2582 self.variables.clear();
2583 self.fn_value_opt = None;
2585 self.tensor_ptr_opt = None;
2586 }
2587
2588 fn function_arg_alloca(&mut self, name: &str, arg: BasicValueEnum<'ctx>) -> PointerValue<'ctx> {
2589 match arg {
2590 BasicValueEnum::PointerValue(v) => v,
2591 BasicValueEnum::FloatValue(v) => {
2592 let alloca = self
2593 .create_entry_block_builder()
2594 .build_alloca(arg.get_type(), name)
2595 .unwrap();
2596 self.builder.build_store(alloca, v).unwrap();
2597 alloca
2598 }
2599 BasicValueEnum::IntValue(v) => {
2600 let alloca = self
2601 .create_entry_block_builder()
2602 .build_alloca(arg.get_type(), name)
2603 .unwrap();
2604 self.builder.build_store(alloca, v).unwrap();
2605 alloca
2606 }
2607 _ => unreachable!(),
2608 }
2609 }
2610
    /// Build the `set_u0(u0, data, thread_id, thread_dim)` function, which
    /// computes the model's initial state vector.
    ///
    /// All input-dependent definitions are compiled first (with a barrier
    /// after each, for the threaded case), then the state tensor is written
    /// directly into the caller-supplied `u0` buffer.
    ///
    /// # Errors
    /// Returns an error if tensor compilation fails or if LLVM verification
    /// of the generated function fails (the broken function is deleted first).
    pub fn compile_set_u0<'m>(&mut self, model: &'m DiscreteModel) -> Result<FunctionValue<'ctx>> {
        self.clear();
        let void_type = self.context.void_type();
        // signature: (real* u0, real* data, int thread_id, int thread_dim)
        let fn_type = void_type.fn_type(
            &[
                self.real_ptr_type.into(),
                self.real_ptr_type.into(),
                self.int_type.into(),
                self.int_type.into(),
            ],
            false,
        );
        let fn_arg_names = &["u0", "data", "thread_id", "thread_dim"];
        let function = self.module.add_function("set_u0", fn_type, None);

        // mark the pointer params (u0, data) as noalias
        let alias_id = Attribute::get_named_enum_kind_id("noalias");
        let noalign = self.context.create_enum_attribute(alias_id, 0);
        for i in &[0, 1] {
            function.add_attribute(AttributeLoc::Param(*i), noalign);
        }

        let basic_block = self.context.append_basic_block(function, "entry");
        self.fn_value_opt = Some(function);
        self.builder.position_at_end(basic_block);

        // spill scalar args to allocas and register every param by name
        for (i, arg) in function.get_param_iter().enumerate() {
            let name = fn_arg_names[i];
            let alloca = self.function_arg_alloca(name, arg);
            self.insert_param(name, alloca);
        }

        self.insert_data(model);
        self.insert_indices();

        // one barrier per input-dependent definition plus one for the state
        let mut nbarriers = 0;
        let total_barriers = (model.input_dep_defns().len() + 1) as u64;
        #[allow(clippy::explicit_counter_loop)]
        for a in model.input_dep_defns() {
            self.jit_compile_tensor(a, Some(*self.get_var(a)))?;
            self.jit_compile_call_barrier(nbarriers, total_barriers);
            nbarriers += 1;
        }

        // write the initial state straight into the caller's u0 buffer
        self.jit_compile_tensor(model.state(), Some(*self.get_param("u0")))?;
        self.jit_compile_call_barrier(nbarriers, total_barriers);

        self.builder.build_return(None)?;

        if function.verify(true) {
            Ok(function)
        } else {
            function.print_to_stderr();
            unsafe {
                function.delete();
            }
            Err(anyhow!("Invalid generated function."))
        }
    }
2670
2671 pub fn compile_calc_out<'m>(
2672 &mut self,
2673 model: &'m DiscreteModel,
2674 include_constants: bool,
2675 ) -> Result<FunctionValue<'ctx>> {
2676 self.clear();
2677 let void_type = self.context.void_type();
2678 let fn_type = void_type.fn_type(
2679 &[
2680 self.real_type.into(),
2681 self.real_ptr_type.into(),
2682 self.real_ptr_type.into(),
2683 self.real_ptr_type.into(),
2684 self.int_type.into(),
2685 self.int_type.into(),
2686 ],
2687 false,
2688 );
2689 let fn_arg_names = &["t", "u", "data", "out", "thread_id", "thread_dim"];
2690 let function_name = if include_constants {
2691 "calc_out_full"
2692 } else {
2693 "calc_out"
2694 };
2695 let function = self.module.add_function(function_name, fn_type, None);
2696
2697 let alias_id = Attribute::get_named_enum_kind_id("noalias");
2699 let noalign = self.context.create_enum_attribute(alias_id, 0);
2700 for i in &[1, 2] {
2701 function.add_attribute(AttributeLoc::Param(*i), noalign);
2702 }
2703
2704 let basic_block = self.context.append_basic_block(function, "entry");
2705 self.fn_value_opt = Some(function);
2706 self.builder.position_at_end(basic_block);
2707
2708 for (i, arg) in function.get_param_iter().enumerate() {
2709 let name = fn_arg_names[i];
2710 let alloca = self.function_arg_alloca(name, arg);
2711 self.insert_param(name, alloca);
2712 }
2713
2714 self.insert_state(model.state());
2715 self.insert_data(model);
2716 self.insert_indices();
2717
2718 if let Some(out) = model.out() {
2724 let mut nbarriers = 0;
2725 let mut total_barriers =
2726 (model.time_dep_defns().len() + model.state_dep_defns().len() + 1) as u64;
2727 if include_constants {
2728 total_barriers += model.input_dep_defns().len() as u64;
2729 for tensor in model.input_dep_defns() {
2731 self.jit_compile_tensor(tensor, Some(*self.get_var(tensor)))?;
2732 self.jit_compile_call_barrier(nbarriers, total_barriers);
2733 nbarriers += 1;
2734 }
2735 }
2736
2737 for tensor in model.time_dep_defns() {
2739 self.jit_compile_tensor(tensor, Some(*self.get_var(tensor)))?;
2740 self.jit_compile_call_barrier(nbarriers, total_barriers);
2741 nbarriers += 1;
2742 }
2743
2744 #[allow(clippy::explicit_counter_loop)]
2746 for a in model.state_dep_defns() {
2747 self.jit_compile_tensor(a, Some(*self.get_var(a)))?;
2748 self.jit_compile_call_barrier(nbarriers, total_barriers);
2749 nbarriers += 1;
2750 }
2751
2752 self.jit_compile_tensor(out, Some(*self.get_var(model.out().unwrap())))?;
2753 self.jit_compile_call_barrier(nbarriers, total_barriers);
2754 }
2755 self.builder.build_return(None)?;
2756
2757 if function.verify(true) {
2758 Ok(function)
2759 } else {
2760 function.print_to_stderr();
2761 unsafe {
2762 function.delete();
2763 }
2764 Err(anyhow!("Invalid generated function."))
2765 }
2766 }
2767
    /// Build the `calc_stop(t, u, data, root, thread_id, thread_dim)`
    /// function, which evaluates the model's stop (root-finding) tensor into
    /// the caller-supplied `root` buffer.
    ///
    /// If the model defines no stop tensor, the generated function is empty
    /// apart from the return. Time- and state-dependent definitions are
    /// compiled first, with a barrier after each for the threaded case.
    ///
    /// # Errors
    /// Returns an error if tensor compilation fails or if LLVM verification
    /// of the generated function fails (the broken function is deleted first).
    pub fn compile_calc_stop<'m>(
        &mut self,
        model: &'m DiscreteModel,
    ) -> Result<FunctionValue<'ctx>> {
        self.clear();
        let void_type = self.context.void_type();
        // signature: (real t, real* u, real* data, real* root, int thread_id, int thread_dim)
        let fn_type = void_type.fn_type(
            &[
                self.real_type.into(),
                self.real_ptr_type.into(),
                self.real_ptr_type.into(),
                self.real_ptr_type.into(),
                self.int_type.into(),
                self.int_type.into(),
            ],
            false,
        );
        let fn_arg_names = &["t", "u", "data", "root", "thread_id", "thread_dim"];
        let function = self.module.add_function("calc_stop", fn_type, None);

        // mark the pointer params (u, data, root) as noalias
        let alias_id = Attribute::get_named_enum_kind_id("noalias");
        let noalign = self.context.create_enum_attribute(alias_id, 0);
        for i in &[1, 2, 3] {
            function.add_attribute(AttributeLoc::Param(*i), noalign);
        }

        let basic_block = self.context.append_basic_block(function, "entry");
        self.fn_value_opt = Some(function);
        self.builder.position_at_end(basic_block);

        // spill scalar args to allocas and register every param by name
        for (i, arg) in function.get_param_iter().enumerate() {
            let name = fn_arg_names[i];
            let alloca = self.function_arg_alloca(name, arg);
            self.insert_param(name, alloca);
        }

        self.insert_state(model.state());
        self.insert_data(model);
        self.insert_indices();

        if let Some(stop) = model.stop() {
            // one barrier per compiled tensor plus one for the stop tensor
            let mut nbarriers = 0;
            let total_barriers =
                (model.time_dep_defns().len() + model.state_dep_defns().len() + 1) as u64;
            for tensor in model.time_dep_defns() {
                self.jit_compile_tensor(tensor, Some(*self.get_var(tensor)))?;
                self.jit_compile_call_barrier(nbarriers, total_barriers);
                nbarriers += 1;
            }

            for a in model.state_dep_defns() {
                self.jit_compile_tensor(a, Some(*self.get_var(a)))?;
                self.jit_compile_call_barrier(nbarriers, total_barriers);
                nbarriers += 1;
            }

            // write the stop tensor straight into the caller's root buffer
            let res_ptr = self.get_param("root");
            self.jit_compile_tensor(stop, Some(*res_ptr))?;
            self.jit_compile_call_barrier(nbarriers, total_barriers);
        }
        self.builder.build_return(None)?;

        if function.verify(true) {
            Ok(function)
        } else {
            function.print_to_stderr();
            unsafe {
                function.delete();
            }
            Err(anyhow!("Invalid generated function."))
        }
    }
2843
    /// Build the `rhs`/`rhs_full` function, which evaluates the model's
    /// right-hand side into the caller-supplied `rr` buffer.
    ///
    /// Signature: `(t, u, data, rr, thread_id, thread_dim)`. When
    /// `include_constants` is true the function is named `rhs_full` and also
    /// re-evaluates the input-dependent definitions; otherwise those are
    /// assumed to already be present in `data`. Time- and state-dependent
    /// definitions are compiled before the rhs tensor, with a barrier after
    /// each for the threaded case.
    ///
    /// # Errors
    /// Returns an error if tensor compilation fails or if LLVM verification
    /// of the generated function fails (the broken function is deleted first).
    pub fn compile_rhs<'m>(
        &mut self,
        model: &'m DiscreteModel,
        include_constants: bool,
    ) -> Result<FunctionValue<'ctx>> {
        self.clear();
        let void_type = self.context.void_type();
        // signature: (real t, real* u, real* data, real* rr, int thread_id, int thread_dim)
        let fn_type = void_type.fn_type(
            &[
                self.real_type.into(),
                self.real_ptr_type.into(),
                self.real_ptr_type.into(),
                self.real_ptr_type.into(),
                self.int_type.into(),
                self.int_type.into(),
            ],
            false,
        );
        let fn_arg_names = &["t", "u", "data", "rr", "thread_id", "thread_dim"];
        let function_name = if include_constants { "rhs_full" } else { "rhs" };
        let function = self.module.add_function(function_name, fn_type, None);

        // mark the pointer params (u, data, rr) as noalias
        let alias_id = Attribute::get_named_enum_kind_id("noalias");
        let noalign = self.context.create_enum_attribute(alias_id, 0);
        for i in &[1, 2, 3] {
            function.add_attribute(AttributeLoc::Param(*i), noalign);
        }

        let basic_block = self.context.append_basic_block(function, "entry");
        self.fn_value_opt = Some(function);
        self.builder.position_at_end(basic_block);

        // spill scalar args to allocas and register every param by name
        for (i, arg) in function.get_param_iter().enumerate() {
            let name = fn_arg_names[i];
            let alloca = self.function_arg_alloca(name, arg);
            self.insert_param(name, alloca);
        }

        self.insert_state(model.state());
        self.insert_data(model);
        self.insert_indices();

        // one barrier per compiled tensor plus one for the rhs tensor itself
        let mut nbarriers = 0;
        let mut total_barriers =
            (model.time_dep_defns().len() + model.state_dep_defns().len() + 1) as u64;
        if include_constants {
            total_barriers += model.input_dep_defns().len() as u64;
            for tensor in model.input_dep_defns() {
                self.jit_compile_tensor(tensor, Some(*self.get_var(tensor)))?;
                self.jit_compile_call_barrier(nbarriers, total_barriers);
                nbarriers += 1;
            }
        }

        for tensor in model.time_dep_defns() {
            self.jit_compile_tensor(tensor, Some(*self.get_var(tensor)))?;
            self.jit_compile_call_barrier(nbarriers, total_barriers);
            nbarriers += 1;
        }

        for a in model.state_dep_defns() {
            self.jit_compile_tensor(a, Some(*self.get_var(a)))?;
            self.jit_compile_call_barrier(nbarriers, total_barriers);
            nbarriers += 1;
        }

        // write the rhs straight into the caller's rr buffer
        let res_ptr = self.get_param("rr");
        self.jit_compile_tensor(model.rhs(), Some(*res_ptr))?;
        self.jit_compile_call_barrier(nbarriers, total_barriers);

        self.builder.build_return(None)?;

        if function.verify(true) {
            Ok(function)
        } else {
            function.print_to_stderr();
            unsafe {
                function.delete();
            }
            Err(anyhow!("Invalid generated function."))
        }
    }
2931
2932 pub fn compile_mass<'m>(&mut self, model: &'m DiscreteModel) -> Result<FunctionValue<'ctx>> {
2933 self.clear();
2934 let void_type = self.context.void_type();
2935 let fn_type = void_type.fn_type(
2936 &[
2937 self.real_type.into(),
2938 self.real_ptr_type.into(),
2939 self.real_ptr_type.into(),
2940 self.real_ptr_type.into(),
2941 self.int_type.into(),
2942 self.int_type.into(),
2943 ],
2944 false,
2945 );
2946 let fn_arg_names = &["t", "dudt", "data", "rr", "thread_id", "thread_dim"];
2947 let function = self.module.add_function("mass", fn_type, None);
2948
2949 let alias_id = Attribute::get_named_enum_kind_id("noalias");
2951 let noalign = self.context.create_enum_attribute(alias_id, 0);
2952 for i in &[1, 2, 3] {
2953 function.add_attribute(AttributeLoc::Param(*i), noalign);
2954 }
2955
2956 let basic_block = self.context.append_basic_block(function, "entry");
2957 self.fn_value_opt = Some(function);
2958 self.builder.position_at_end(basic_block);
2959
2960 for (i, arg) in function.get_param_iter().enumerate() {
2961 let name = fn_arg_names[i];
2962 let alloca = self.function_arg_alloca(name, arg);
2963 self.insert_param(name, alloca);
2964 }
2965
2966 if model.state_dot().is_some() && model.lhs().is_some() {
2968 let state_dot = model.state_dot().unwrap();
2969 let lhs = model.lhs().unwrap();
2970
2971 self.insert_dot_state(state_dot);
2972 self.insert_data(model);
2973 self.insert_indices();
2974
2975 let mut nbarriers = 0;
2977 let total_barriers =
2978 (model.time_dep_defns().len() + model.dstate_dep_defns().len() + 1) as u64;
2979 for tensor in model.time_dep_defns() {
2980 self.jit_compile_tensor(tensor, Some(*self.get_var(tensor)))?;
2981 self.jit_compile_call_barrier(nbarriers, total_barriers);
2982 nbarriers += 1;
2983 }
2984
2985 for a in model.dstate_dep_defns() {
2986 self.jit_compile_tensor(a, Some(*self.get_var(a)))?;
2987 self.jit_compile_call_barrier(nbarriers, total_barriers);
2988 nbarriers += 1;
2989 }
2990
2991 let res_ptr = self.get_param("rr");
2993 self.jit_compile_tensor(lhs, Some(*res_ptr))?;
2994 self.jit_compile_call_barrier(nbarriers, total_barriers);
2995 }
2996
2997 self.builder.build_return(None)?;
2998
2999 if function.verify(true) {
3000 Ok(function)
3001 } else {
3002 function.print_to_stderr();
3003 unsafe {
3004 function.delete();
3005 }
3006 Err(anyhow!("Invalid generated function."))
3007 }
3008 }
3009
3010 pub fn compile_gradient(
3011 &mut self,
3012 original_function: FunctionValue<'ctx>,
3013 args_type: &[CompileGradientArgType],
3014 mode: CompileMode,
3015 fn_name: &str,
3016 ) -> Result<FunctionValue<'ctx>> {
3017 self.clear();
3018
3019 let mut fn_type: Vec<BasicMetadataTypeEnum> = Vec::new();
3021
3022 let orig_fn_type_ptr = Self::fn_pointer_type(self.context, original_function.get_type());
3023
3024 let mut enzyme_fn_type: Vec<BasicMetadataTypeEnum> = vec![orig_fn_type_ptr.into()];
3025 let mut start_param_index: Vec<u32> = Vec::new();
3026 let mut ptr_arg_indices: Vec<u32> = Vec::new();
3027 for (i, arg) in original_function.get_param_iter().enumerate() {
3028 start_param_index.push(u32::try_from(fn_type.len()).unwrap());
3029 let arg_type = arg.get_type();
3030 fn_type.push(arg_type.into());
3031
3032 enzyme_fn_type.push(self.int_type.into());
3034 enzyme_fn_type.push(arg.get_type().into());
3035
3036 if arg_type.is_pointer_type() {
3037 ptr_arg_indices.push(u32::try_from(i).unwrap());
3038 }
3039
3040 match args_type[i] {
3041 CompileGradientArgType::Dup | CompileGradientArgType::DupNoNeed => {
3042 fn_type.push(arg.get_type().into());
3043 enzyme_fn_type.push(arg.get_type().into());
3044 }
3045 CompileGradientArgType::Const => {}
3046 }
3047 }
3048 let void_type = self.context.void_type();
3049 let fn_type = void_type.fn_type(fn_type.as_slice(), false);
3050 let function = self.module.add_function(fn_name, fn_type, None);
3051
3052 let alias_id = Attribute::get_named_enum_kind_id("noalias");
3054 let noalign = self.context.create_enum_attribute(alias_id, 0);
3055 for i in ptr_arg_indices {
3056 function.add_attribute(AttributeLoc::Param(i), noalign);
3057 }
3058
3059 let basic_block = self.context.append_basic_block(function, "entry");
3060 self.fn_value_opt = Some(function);
3061 self.builder.position_at_end(basic_block);
3062
3063 let mut enzyme_fn_args: Vec<BasicMetadataValueEnum> = Vec::new();
3064 let mut input_activity = Vec::new();
3065 let mut arg_trees = Vec::new();
3066 for (i, arg) in original_function.get_param_iter().enumerate() {
3067 let param_index = start_param_index[i];
3068 let fn_arg = function.get_nth_param(param_index).unwrap();
3069
3070 let concrete_type = match arg.get_type() {
3073 BasicTypeEnum::PointerType(_t) => CConcreteType_DT_Pointer,
3074 BasicTypeEnum::FloatType(_t) => match self.diffsl_real_type {
3075 RealType::F32 => CConcreteType_DT_Float,
3076 RealType::F64 => CConcreteType_DT_Double,
3077 },
3078 BasicTypeEnum::IntType(_) => CConcreteType_DT_Integer,
3079 _ => panic!("unsupported type"),
3080 };
3081 let new_tree = unsafe {
3082 EnzymeNewTypeTreeCT(
3083 concrete_type,
3084 self.context.as_ctx_ref() as *mut LLVMOpaqueContext,
3085 )
3086 };
3087 unsafe { EnzymeTypeTreeOnlyEq(new_tree, -1) };
3088
3089 if concrete_type == CConcreteType_DT_Pointer {
3091 let inner_concrete_type = match self.diffsl_real_type {
3092 RealType::F32 => CConcreteType_DT_Float,
3093 RealType::F64 => CConcreteType_DT_Double,
3094 };
3095 let inner_new_tree = unsafe {
3096 EnzymeNewTypeTreeCT(
3097 inner_concrete_type,
3098 self.context.as_ctx_ref() as *mut LLVMOpaqueContext,
3099 )
3100 };
3101 unsafe { EnzymeTypeTreeOnlyEq(inner_new_tree, -1) };
3102 unsafe { EnzymeTypeTreeOnlyEq(inner_new_tree, -1) };
3103 unsafe { EnzymeMergeTypeTree(new_tree, inner_new_tree) };
3104 }
3105
3106 arg_trees.push(new_tree);
3107 match args_type[i] {
3108 CompileGradientArgType::Dup => {
3109 enzyme_fn_args.push(fn_arg.into());
3111
3112 let fn_darg = function.get_nth_param(param_index + 1).unwrap();
3114 enzyme_fn_args.push(fn_darg.into());
3115
3116 input_activity.push(CDIFFE_TYPE_DFT_DUP_ARG);
3117 }
3118 CompileGradientArgType::DupNoNeed => {
3119 enzyme_fn_args.push(fn_arg.into());
3121
3122 let fn_darg = function.get_nth_param(param_index + 1).unwrap();
3124 enzyme_fn_args.push(fn_darg.into());
3125
3126 input_activity.push(CDIFFE_TYPE_DFT_DUP_NONEED);
3127 }
3128 CompileGradientArgType::Const => {
3129 enzyme_fn_args.push(fn_arg.into());
3131
3132 input_activity.push(CDIFFE_TYPE_DFT_CONSTANT);
3133 }
3134 }
3135 }
3136 let ret_primary_ret = false;
3138 let diff_ret = false;
3139 let ret_activity = CDIFFE_TYPE_DFT_CONSTANT;
3140 let ret_tree = unsafe {
3141 EnzymeNewTypeTreeCT(
3142 CConcreteType_DT_Anything,
3143 self.context.as_ctx_ref() as *mut LLVMOpaqueContext,
3144 )
3145 };
3146
3147 let fnc_opt_base = true;
3149 let logic_ref: EnzymeLogicRef = unsafe { CreateEnzymeLogic(fnc_opt_base as u8) };
3150
3151 let kv_tmp = IntList {
3152 data: std::ptr::null_mut(),
3153 size: 0,
3154 };
3155 let mut known_values = vec![kv_tmp; input_activity.len()];
3156
3157 let fn_type_info = CFnTypeInfo {
3158 Arguments: arg_trees.as_mut_ptr(),
3159 Return: ret_tree,
3160 KnownValues: known_values.as_mut_ptr(),
3161 };
3162
3163 let type_analysis: EnzymeTypeAnalysisRef =
3164 unsafe { CreateTypeAnalysis(logic_ref, std::ptr::null_mut(), std::ptr::null_mut(), 0) };
3165
3166 let mut args_uncacheable = vec![0; arg_trees.len()];
3167
3168 let enzyme_function = match mode {
3169 CompileMode::Forward | CompileMode::ForwardSens => unsafe {
3170 EnzymeCreateForwardDiff(
3171 logic_ref, std::ptr::null_mut(),
3173 std::ptr::null_mut(),
3174 original_function.as_value_ref(),
3175 ret_activity, input_activity.as_mut_ptr(),
3177 input_activity.len(), type_analysis, ret_primary_ret as u8,
3180 CDerivativeMode_DEM_ForwardMode, 1, 0, 0, 1, std::ptr::null_mut(), fn_type_info, 1, args_uncacheable.as_mut_ptr(), args_uncacheable.len(), std::ptr::null_mut(), )
3192 },
3193 CompileMode::Reverse | CompileMode::ReverseSens => {
3194 let mut call_enzyme = || unsafe {
3195 EnzymeCreatePrimalAndGradient(
3196 logic_ref,
3197 std::ptr::null_mut(),
3198 std::ptr::null_mut(),
3199 original_function.as_value_ref(),
3200 ret_activity,
3201 input_activity.as_mut_ptr(),
3202 input_activity.len(),
3203 type_analysis,
3204 ret_primary_ret as u8,
3205 diff_ret as u8,
3206 CDerivativeMode_DEM_ReverseModeCombined,
3207 0,
3208 0, 1,
3210 1,
3211 std::ptr::null_mut(),
3212 0, fn_type_info,
3214 0, args_uncacheable.as_mut_ptr(),
3216 args_uncacheable.len(),
3217 std::ptr::null_mut(),
3218 if self.threaded { 1 } else { 0 }, )
3220 };
3221 if self.threaded {
3222 let _lock = my_mutex.lock().unwrap();
3224 let barrier_string = CString::new("barrier").unwrap();
3225 unsafe {
3226 EnzymeRegisterCallHandler(
3227 barrier_string.as_ptr(),
3228 Some(fwd_handler),
3229 Some(rev_handler),
3230 )
3231 };
3232 let ret = call_enzyme();
3233 unsafe { EnzymeRegisterCallHandler(barrier_string.as_ptr(), None, None) };
3235 ret
3236 } else {
3237 call_enzyme()
3238 }
3239 }
3240 };
3241
3242 unsafe { FreeEnzymeLogic(logic_ref) };
3244 unsafe { FreeTypeAnalysis(type_analysis) };
3245 unsafe { EnzymeFreeTypeTree(ret_tree) };
3246 for tree in arg_trees {
3247 unsafe { EnzymeFreeTypeTree(tree) };
3248 }
3249
3250 let enzyme_function =
3252 unsafe { FunctionValue::new(enzyme_function as LLVMValueRef) }.unwrap();
3253 self.builder
3254 .build_call(enzyme_function, enzyme_fn_args.as_slice(), "enzyme_call")?;
3255
3256 self.builder.build_return(None)?;
3258
3259 if function.verify(true) {
3260 Ok(function)
3261 } else {
3262 function.print_to_stderr();
3263 enzyme_function.print_to_stderr();
3264 unsafe {
3265 function.delete();
3266 }
3267 Err(anyhow!("Invalid generated function."))
3268 }
3269 }
3270
3271 pub fn compile_get_dims(&mut self, model: &DiscreteModel) -> Result<FunctionValue<'ctx>> {
3272 self.clear();
3273 let fn_type = self.context.void_type().fn_type(
3274 &[
3275 self.int_ptr_type.into(),
3276 self.int_ptr_type.into(),
3277 self.int_ptr_type.into(),
3278 self.int_ptr_type.into(),
3279 self.int_ptr_type.into(),
3280 self.int_ptr_type.into(),
3281 ],
3282 false,
3283 );
3284
3285 let function = self.module.add_function("get_dims", fn_type, None);
3286 let block = self.context.append_basic_block(function, "entry");
3287 let fn_arg_names = &["states", "inputs", "outputs", "data", "stop", "has_mass"];
3288 self.builder.position_at_end(block);
3289
3290 for (i, arg) in function.get_param_iter().enumerate() {
3291 let name = fn_arg_names[i];
3292 let alloca = self.function_arg_alloca(name, arg);
3293 self.insert_param(name, alloca);
3294 }
3295
3296 self.insert_indices();
3297
3298 let number_of_states = model.state().nnz() as u64;
3299 let number_of_inputs = model.input().map(|inp| inp.nnz()).unwrap_or(0) as u64;
3300 let number_of_outputs = match model.out() {
3301 Some(out) => out.nnz() as u64,
3302 None => 0,
3303 };
3304 let number_of_stop = if let Some(stop) = model.stop() {
3305 stop.nnz() as u64
3306 } else {
3307 0
3308 };
3309 let has_mass = match model.lhs().is_some() {
3310 true => 1u64,
3311 false => 0u64,
3312 };
3313 let data_len = self.layout.data().len() as u64;
3314 self.builder.build_store(
3315 *self.get_param("states"),
3316 self.int_type.const_int(number_of_states, false),
3317 )?;
3318 self.builder.build_store(
3319 *self.get_param("inputs"),
3320 self.int_type.const_int(number_of_inputs, false),
3321 )?;
3322 self.builder.build_store(
3323 *self.get_param("outputs"),
3324 self.int_type.const_int(number_of_outputs, false),
3325 )?;
3326 self.builder.build_store(
3327 *self.get_param("data"),
3328 self.int_type.const_int(data_len, false),
3329 )?;
3330 self.builder.build_store(
3331 *self.get_param("stop"),
3332 self.int_type.const_int(number_of_stop, false),
3333 )?;
3334 self.builder.build_store(
3335 *self.get_param("has_mass"),
3336 self.int_type.const_int(has_mass, false),
3337 )?;
3338 self.builder.build_return(None)?;
3339
3340 if function.verify(true) {
3341 Ok(function)
3342 } else {
3343 function.print_to_stderr();
3344 unsafe {
3345 function.delete();
3346 }
3347 Err(anyhow!("Invalid generated function."))
3348 }
3349 }
3350
3351 pub fn compile_get_tensor(
3352 &mut self,
3353 model: &DiscreteModel,
3354 name: &str,
3355 ) -> Result<FunctionValue<'ctx>> {
3356 self.clear();
3357 let real_ptr_ptr_type = Self::pointer_type(self.context, self.real_ptr_type.into());
3358 let fn_type = self.context.void_type().fn_type(
3359 &[
3360 self.real_ptr_type.into(),
3361 real_ptr_ptr_type.into(),
3362 self.int_ptr_type.into(),
3363 ],
3364 false,
3365 );
3366 let function_name = format!("get_tensor_{name}");
3367 let function = self
3368 .module
3369 .add_function(function_name.as_str(), fn_type, None);
3370 let basic_block = self.context.append_basic_block(function, "entry");
3371 self.fn_value_opt = Some(function);
3372
3373 let fn_arg_names = &["data", "tensor_data", "tensor_size"];
3374 self.builder.position_at_end(basic_block);
3375
3376 for (i, arg) in function.get_param_iter().enumerate() {
3377 let name = fn_arg_names[i];
3378 let alloca = self.function_arg_alloca(name, arg);
3379 self.insert_param(name, alloca);
3380 }
3381
3382 self.insert_data(model);
3383 let ptr = self.get_param(name);
3384 let tensor_size = self.layout.get_layout(name).unwrap().nnz() as u64;
3385 let tensor_size_value = self.int_type.const_int(tensor_size, false);
3386 self.builder
3387 .build_store(*self.get_param("tensor_data"), ptr.as_basic_value_enum())?;
3388 self.builder
3389 .build_store(*self.get_param("tensor_size"), tensor_size_value)?;
3390 self.builder.build_return(None)?;
3391
3392 if function.verify(true) {
3393 Ok(function)
3394 } else {
3395 function.print_to_stderr();
3396 unsafe {
3397 function.delete();
3398 }
3399 Err(anyhow!("Invalid generated function."))
3400 }
3401 }
3402
3403 pub fn compile_get_constant(
3404 &mut self,
3405 model: &DiscreteModel,
3406 name: &str,
3407 ) -> Result<FunctionValue<'ctx>> {
3408 self.clear();
3409 let real_ptr_ptr_type = Self::pointer_type(self.context, self.real_ptr_type.into());
3410 let fn_type = self
3411 .context
3412 .void_type()
3413 .fn_type(&[real_ptr_ptr_type.into(), self.int_ptr_type.into()], false);
3414 let function_name = format!("get_constant_{name}");
3415 let function = self
3416 .module
3417 .add_function(function_name.as_str(), fn_type, None);
3418 let basic_block = self.context.append_basic_block(function, "entry");
3419 self.fn_value_opt = Some(function);
3420
3421 let fn_arg_names = &["tensor_data", "tensor_size"];
3422 self.builder.position_at_end(basic_block);
3423
3424 for (i, arg) in function.get_param_iter().enumerate() {
3425 let name = fn_arg_names[i];
3426 let alloca = self.function_arg_alloca(name, arg);
3427 self.insert_param(name, alloca);
3428 }
3429
3430 self.insert_constants(model);
3431 let ptr = self.get_param(name);
3432 let tensor_size = self.layout.get_layout(name).unwrap().nnz() as u64;
3433 let tensor_size_value = self.int_type.const_int(tensor_size, false);
3434 self.builder
3435 .build_store(*self.get_param("tensor_data"), ptr.as_basic_value_enum())?;
3436 self.builder
3437 .build_store(*self.get_param("tensor_size"), tensor_size_value)?;
3438 self.builder.build_return(None)?;
3439
3440 if function.verify(true) {
3441 Ok(function)
3442 } else {
3443 function.print_to_stderr();
3444 unsafe {
3445 function.delete();
3446 }
3447 Err(anyhow!("Invalid generated function."))
3448 }
3449 }
3450
    /// Builds either `get_inputs` (copies the model's input tensor into the
    /// caller's `inputs` buffer) or `set_inputs` (copies the caller's
    /// `inputs` buffer into the input tensor), selected by `is_get`.
    /// Both variants take two real-pointer arguments: `(inputs, data)`.
    ///
    /// Returns the generated function, or an error if LLVM verification
    /// fails (the invalid function is deleted from the module first).
    pub fn compile_inputs(
        &mut self,
        model: &DiscreteModel,
        is_get: bool,
    ) -> Result<FunctionValue<'ctx>> {
        self.clear();
        let void_type = self.context.void_type();
        let fn_type = void_type.fn_type(
            &[self.real_ptr_type.into(), self.real_ptr_type.into()],
            false,
        );
        let function_name = if is_get { "get_inputs" } else { "set_inputs" };
        let function = self.module.add_function(function_name, fn_type, None);
        let block = self.context.append_basic_block(function, "entry");
        self.fn_value_opt = Some(function);

        let fn_arg_names = &["inputs", "data"];
        self.builder.position_at_end(block);

        // Spill each argument to a named alloca and register it for lookup.
        for (i, arg) in function.get_param_iter().enumerate() {
            let name = fn_arg_names[i];
            let alloca = self.function_arg_alloca(name, arg);
            self.insert_param(name, alloca);
        }

        // If the model has no input tensor, the function body is just `ret`.
        if let Some(input) = model.input() {
            let name = input.name();
            self.insert_tensor(input, false);
            let ptr = self.get_var(input);
            // Offset of this input within the `inputs` buffer (currently 0,
            // kept separate from the loop start index for clarity).
            let inputs_start_index = self.int_type.const_int(0, false);
            let start_index = self.int_type.const_int(0, false);
            let end_index = self
                .int_type
                .const_int(input.nnz().try_into().unwrap(), false);

            // Hand-rolled copy loop over the input's non-zeros, built with an
            // explicit phi node: i starts at start_index on entry from `block`.
            let input_block = self.context.append_basic_block(function, name);
            self.builder.build_unconditional_branch(input_block)?;
            self.builder.position_at_end(input_block);
            let index = self.builder.build_phi(self.int_type, "i")?;
            index.add_incoming(&[(&start_index, block)]);

            let curr_input_index = index.as_basic_value().into_int_value();
            let input_ptr =
                Self::get_ptr_to_index(&self.builder, self.real_type, ptr, curr_input_index, name);
            let curr_inputs_index =
                self.builder
                    .build_int_add(inputs_start_index, curr_input_index, name)?;
            let inputs_ptr = Self::get_ptr_to_index(
                &self.builder,
                self.real_type,
                self.get_param("inputs"),
                curr_inputs_index,
                name,
            );
            // Copy one element per iteration; direction depends on `is_get`.
            if is_get {
                // get: tensor element -> caller's inputs buffer.
                let input_value = self
                    .build_load(self.real_type, input_ptr, name)?
                    .into_float_value();
                self.builder.build_store(inputs_ptr, input_value)?;
            } else {
                // set: caller's inputs buffer -> tensor element.
                let input_value = self
                    .build_load(self.real_type, inputs_ptr, name)?
                    .into_float_value();
                self.builder.build_store(input_ptr, input_value)?;
            }

            // Increment i and close the phi's back-edge from the loop body.
            let one = self.int_type.const_int(1, false);
            let next_index = self.builder.build_int_add(curr_input_index, one, name)?;
            index.add_incoming(&[(&next_index, input_block)]);

            // Loop while next_index < end_index (unsigned); otherwise fall
            // through to the post block.
            let loop_while =
                self.builder
                    .build_int_compare(IntPredicate::ULT, next_index, end_index, name)?;
            let post_block = self.context.append_basic_block(function, name);
            self.builder
                .build_conditional_branch(loop_while, input_block, post_block)?;
            self.builder.position_at_end(post_block);
        }
        self.builder.build_return(None)?;

        if function.verify(true) {
            Ok(function)
        } else {
            function.print_to_stderr();
            unsafe {
                function.delete();
            }
            Err(anyhow!("Invalid generated function."))
        }
    }
3545
    /// Builds the `set_id` function, which fills the caller-supplied `id`
    /// array with one real value per state entry: 1.0 for differential
    /// (non-algebraic) states and 0.0 for algebraic ones, iterating over the
    /// state tensor's blocks in order.
    ///
    /// Returns the generated function, or an error if LLVM verification
    /// fails (the invalid function is deleted from the module first).
    pub fn compile_set_id(&mut self, model: &DiscreteModel) -> Result<FunctionValue<'ctx>> {
        self.clear();
        let void_type = self.context.void_type();
        let fn_type = void_type.fn_type(&[self.real_ptr_type.into()], false);
        let function = self.module.add_function("set_id", fn_type, None);
        // `block` is re-assigned to each loop's post block so that the next
        // phi node gets the correct predecessor for its entry edge.
        let mut block = self.context.append_basic_block(function, "entry");

        let fn_arg_names = &["id"];
        self.builder.position_at_end(block);

        // Spill the single argument to a named alloca and register it.
        for (i, arg) in function.get_param_iter().enumerate() {
            let name = fn_arg_names[i];
            let alloca = self.function_arg_alloca(name, arg);
            self.insert_param(name, alloca);
        }

        // Running offset into the `id` array across state blocks.
        let mut id_index = 0usize;
        for (blk, is_algebraic) in zip(model.state().elmts(), model.is_algebraic()) {
            let name = blk.name().unwrap_or("unknown");
            let id_start_index = self.int_type.const_int(id_index as u64, false);
            let blk_start_index = self.int_type.const_int(0, false);
            let blk_end_index = self
                .int_type
                .const_int(blk.nnz().try_into().unwrap(), false);

            // Hand-rolled loop over this block's non-zeros, built with an
            // explicit phi node: i starts at 0 on entry from `block`.
            let blk_block = self.context.append_basic_block(function, name);
            self.builder.build_unconditional_branch(blk_block)?;
            self.builder.position_at_end(blk_block);
            let index = self.builder.build_phi(self.int_type, "i")?;
            index.add_incoming(&[(&blk_start_index, block)]);

            let curr_blk_index = index.as_basic_value().into_int_value();
            // id[id_start_index + i] = is_algebraic ? 0.0 : 1.0
            let curr_id_index = self
                .builder
                .build_int_add(id_start_index, curr_blk_index, name)?;
            let id_ptr = Self::get_ptr_to_index(
                &self.builder,
                self.real_type,
                self.get_param("id"),
                curr_id_index,
                name,
            );
            let is_algebraic_float = if *is_algebraic { 0.0 } else { 1.0 };
            let is_algebraic_value = self.real_type.const_float(is_algebraic_float);
            self.builder.build_store(id_ptr, is_algebraic_value)?;

            // Increment i and close the phi's back-edge from the loop body.
            let one = self.int_type.const_int(1, false);
            let next_index = self.builder.build_int_add(curr_blk_index, one, name)?;
            index.add_incoming(&[(&next_index, blk_block)]);

            // Loop while next_index < blk_end_index (unsigned); otherwise
            // fall through to the post block.
            let loop_while = self.builder.build_int_compare(
                IntPredicate::ULT,
                next_index,
                blk_end_index,
                name,
            )?;
            let post_block = self.context.append_basic_block(function, name);
            self.builder
                .build_conditional_branch(loop_while, blk_block, post_block)?;
            self.builder.position_at_end(post_block);

            // The post block becomes the predecessor for the next block's phi.
            block = post_block;
            id_index += blk.nnz();
        }
        self.builder.build_return(None)?;

        if function.verify(true) {
            Ok(function)
        } else {
            function.print_to_stderr();
            unsafe {
                function.delete();
            }
            Err(anyhow!("Invalid generated function."))
        }
    }
3627
    /// Returns a reference to the underlying LLVM module being built.
    pub fn module(&self) -> &Module<'ctx> {
        &self.module
    }
3631}