1use aliasable::boxed::AliasableBox;
2use anyhow::{anyhow, Result};
3use inkwell::attributes::{Attribute, AttributeLoc};
4use inkwell::basic_block::BasicBlock;
5use inkwell::builder::Builder;
6use inkwell::context::{AsContextRef, Context};
7use inkwell::execution_engine::ExecutionEngine;
8use inkwell::intrinsics::Intrinsic;
9use inkwell::module::{Linkage, Module};
10use inkwell::passes::PassBuilderOptions;
11use inkwell::targets::{FileType, InitializationConfig, Target, TargetMachine, TargetTriple};
12use inkwell::types::{
13 BasicMetadataTypeEnum, BasicType, BasicTypeEnum, FloatType, FunctionType, IntType, PointerType,
14};
15use inkwell::values::{
16 AsValueRef, BasicMetadataValueEnum, BasicValue, BasicValueEnum, CallSiteValue, FloatValue,
17 FunctionValue, GlobalValue, IntValue, PointerValue,
18};
19use inkwell::{
20 AddressSpace, AtomicOrdering, AtomicRMWBinOp, FloatPredicate, GlobalVisibility, IntPredicate,
21 OptimizationLevel,
22};
23use llvm_sys::core::{
24 LLVMBuildCall2, LLVMGetArgOperand, LLVMGetBasicBlockParent, LLVMGetGlobalParent,
25 LLVMGetInstructionParent, LLVMGetNamedFunction, LLVMGlobalGetValueType,
26};
27use llvm_sys::prelude::{LLVMBuilderRef, LLVMValueRef};
28use std::collections::HashMap;
29use std::ffi::CString;
30use std::iter::zip;
31use std::pin::Pin;
32use target_lexicon::Triple;
33
// Native scalar type for model "real" values.
// NOTE(review): not referenced in this part of the file — presumably used
// further down; confirm before removing.
type RealType = f64;
35
36use crate::ast::{Ast, AstKind};
37use crate::discretise::{DiscreteModel, Tensor, TensorBlock};
38use crate::enzyme::{
39 CConcreteType_DT_Anything, CConcreteType_DT_Double, CConcreteType_DT_Integer,
40 CConcreteType_DT_Pointer, CDerivativeMode_DEM_ForwardMode,
41 CDerivativeMode_DEM_ReverseModeCombined, CFnTypeInfo, CreateEnzymeLogic, CreateTypeAnalysis,
42 DiffeGradientUtils, EnzymeCreateForwardDiff, EnzymeCreatePrimalAndGradient, EnzymeFreeTypeTree,
43 EnzymeGradientUtilsNewFromOriginal, EnzymeLogicRef, EnzymeMergeTypeTree, EnzymeNewTypeTreeCT,
44 EnzymeRegisterCallHandler, EnzymeTypeAnalysisRef, EnzymeTypeTreeOnlyEq, FreeEnzymeLogic,
45 FreeTypeAnalysis, GradientUtils, IntList, LLVMOpaqueContext, CDIFFE_TYPE_DFT_CONSTANT,
46 CDIFFE_TYPE_DFT_DUP_ARG, CDIFFE_TYPE_DFT_DUP_NONEED,
47};
48use crate::execution::compiler::CompilerMode;
49use crate::execution::module::{
50 CodegenModule, CodegenModuleCompile, CodegenModuleEmit, CodegenModuleJit,
51};
52use crate::execution::{DataLayout, Translation, TranslationFrom, TranslationTo};
53use lazy_static::lazy_static;
54use std::sync::Mutex;
55
lazy_static! {
    // Process-wide lock; presumably serialises access to Enzyme's global,
    // non-thread-safe state during gradient compilation — TODO confirm at
    // the lock's use sites. (A modern alternative would be
    // `std::sync::LazyLock`, but renaming/retyping would affect callers.)
    static ref my_mutex: Mutex<i32> = Mutex::new(0i32);
}
59
/// Self-referential pair of an LLVM `Context` and the `CodeGen` that borrows
/// it. Must never move once constructed (hence `PhantomPinned`): `codegen`
/// holds pointers into `context`.
struct ImmovableLlvmModule {
    // Declared before `context` so it is dropped first; its `'static`
    // lifetime is a lie that only holds while `context` is alive (see the
    // transmute in `LlvmModule::new`).
    codegen: Option<CodeGen<'static>>,
    // AliasableBox gives the context a stable heap address that `codegen`
    // may alias.
    context: AliasableBox<Context>,
    _pin: std::marker::PhantomPinned,
}
68
/// A compiled LLVM module for one discrete model, together with the target
/// machine used to optimise and emit it.
pub struct LlvmModule {
    // Pinned because the inner struct is self-referential.
    inner: Pin<Box<ImmovableLlvmModule>>,
    machine: TargetMachine,
}
73
// SAFETY(review): LLVM contexts are not thread-safe in general; this asserts
// that shared `&LlvmModule` access across threads is sound — presumably
// upheld by external synchronisation (cf. `my_mutex`) — TODO confirm.
unsafe impl Sync for LlvmModule {}
75
76impl LlvmModule {
77 fn new(triple: Option<Triple>, model: &DiscreteModel, threaded: bool) -> Result<Self> {
78 let initialization_config = &InitializationConfig::default();
79 Target::initialize_all(initialization_config);
80 let host_triple = Triple::host();
81 let (triple_str, native) = match triple {
82 Some(ref triple) => (triple.to_string(), false),
83 None => (host_triple.to_string(), true),
84 };
85 let triple = TargetTriple::create(triple_str.as_str());
86 let target = Target::from_triple(&triple).unwrap();
87 let cpu = if native {
88 TargetMachine::get_host_cpu_name().to_string()
89 } else {
90 "generic".to_string()
91 };
92 let features = if native {
93 TargetMachine::get_host_cpu_features().to_string()
94 } else {
95 "".to_string()
96 };
97 let machine = target
98 .create_target_machine(
99 &triple,
100 cpu.as_str(),
101 features.as_str(),
102 inkwell::OptimizationLevel::Aggressive,
103 inkwell::targets::RelocMode::Default,
104 inkwell::targets::CodeModel::Default,
105 )
106 .unwrap();
107
108 let context = AliasableBox::from_unique(Box::new(Context::create()));
109 let mut pinned = Self {
110 inner: Box::pin(ImmovableLlvmModule {
111 codegen: None,
112 context,
113 _pin: std::marker::PhantomPinned,
114 }),
115 machine,
116 };
117
118 let context_ref = pinned.inner.context.as_ref();
119 let real_type_str = "f64";
120 let codegen = CodeGen::new(
121 model,
122 context_ref,
123 context_ref.f64_type(),
124 context_ref.i32_type(),
125 real_type_str,
126 threaded,
127 )?;
128 let codegen = unsafe { std::mem::transmute::<CodeGen<'_>, CodeGen<'static>>(codegen) };
129 unsafe { pinned.inner.as_mut().get_unchecked_mut().codegen = Some(codegen) };
130 Ok(pinned)
131 }
132
133 fn pre_autodiff_optimisation(&mut self) -> Result<()> {
134 let pass_options = PassBuilderOptions::create();
143 pass_options.set_loop_vectorization(false);
147 pass_options.set_loop_slp_vectorization(false);
148 pass_options.set_loop_unrolling(false);
149 let passes = "annotation2metadata,forceattrs,inferattrs,coro-early,function<eager-inv>(lower-expect,simplifycfg<bonus-inst-threshold=1;no-forward-switch-cond;no-switch-range-to-icmp;no-switch-to-lookup;keep-loops;no-hoist-common-insts;no-sink-common-insts>,early-cse<>),openmp-opt,ipsccp,called-value-propagation,globalopt,function(mem2reg),function<eager-inv>(instcombine,simplifycfg<bonus-inst-threshold=1;no-forward-switch-cond;switch-range-to-icmp;no-switch-to-lookup;keep-loops;no-hoist-common-insts;no-sink-common-insts>),require<globals-aa>,function(invalidate<aa>),require<profile-summary>,cgscc(devirt<4>(inline<only-mandatory>,inline,function-attrs,openmp-opt-cgscc,function<eager-inv>(early-cse<memssa>,speculative-execution,jump-threading,correlated-propagation,simplifycfg<bonus-inst-threshold=1;no-forward-switch-cond;switch-range-to-icmp;no-switch-to-lookup;keep-loops;no-hoist-common-insts;no-sink-common-insts>,instcombine,libcalls-shrinkwrap,tailcallelim,simplifycfg<bonus-inst-threshold=1;no-forward-switch-cond;switch-range-to-icmp;no-switch-to-lookup;keep-loops;no-hoist-common-insts;no-sink-common-insts>,reassociate,require<opt-remark-emit>,loop-mssa(loop-instsimplify,loop-simplifycfg,licm<no-allowspeculation>,loop-rotate,licm<allowspeculation>,simple-loop-unswitch<no-nontrivial;trivial>),simplifycfg<bonus-inst-threshold=1;no-forward-switch-cond;switch-range-to-icmp;no-switch-to-lookup;keep-loops;no-hoist-common-insts;no-sink-common-insts>,instcombine,loop(loop-idiom,indvars,loop-deletion),vector-combine,mldst-motion<no-split-footer-bb>,gvn<>,sccp,bdce,instcombine,jump-threading,correlated-propagation,adce,memcpyopt,dse,loop-mssa(licm<allowspeculation>),coro-elide,simplifycfg<bonus-inst-threshold=1;no-forward-switch-cond;switch-range-to-icmp;no-switch-to-lookup;keep-loops;hoist-common-insts;sink-common-insts>,instcombine),coro-split)),deadargelim,coro-cleanup,globalopt,globaldce,elim-avail-extern,rpo-function-attrs,recompute-globalsaa,function<eager-inv>(fl
oat2int,lower-constant-intrinsics,loop(loop-rotate,loop-deletion),loop-distribute,inject-tli-mappings,loop-load-elim,instcombine,simplifycfg<bonus-inst-threshold=1;forward-switch-cond;switch-range-to-icmp;switch-to-lookup;no-keep-loops;hoist-common-insts;sink-common-insts>,vector-combine,instcombine,transform-warning,instcombine,require<opt-remark-emit>,loop-mssa(licm<allowspeculation>),alignment-from-assumptions,loop-sink,instsimplify,div-rem-pairs,tailcallelim,simplifycfg<bonus-inst-threshold=1;no-forward-switch-cond;switch-range-to-icmp;no-switch-to-lookup;keep-loops;no-hoist-common-insts;no-sink-common-insts>),globaldce,constmerge,cg-profile,rel-lookup-table-converter,function(annotation-remarks),verify";
157 let (codegen, machine) = self.codegen_and_machine_mut();
158 codegen
159 .module()
160 .run_passes(passes, machine, pass_options)
161 .map_err(|e| anyhow!("Failed to run passes: {:?}", e))
162 }
163
164 fn post_autodiff_optimisation(&mut self) -> Result<()> {
165 if let Some(barrier_func) = self.codegen_mut().module().get_function("barrier") {
167 let nolinline_kind_id = Attribute::get_named_enum_kind_id("noinline");
168 barrier_func.remove_enum_attribute(AttributeLoc::Function, nolinline_kind_id);
169 }
170
171 for f in self.codegen_mut().module.get_functions() {
173 if f.get_name().to_str().unwrap().starts_with("preprocess_") {
174 unsafe { f.delete() };
175 }
176 }
177
178 let passes = "default<O3>";
184 let (codegen, machine) = self.codegen_and_machine_mut();
185 codegen
186 .module()
187 .run_passes(passes, machine, PassBuilderOptions::create())
188 .map_err(|e| anyhow!("Failed to run passes: {:?}", e))?;
189
190 Ok(())
191 }
192
193 pub fn print(&self) {
194 self.codegen().module().print_to_stderr();
195 }
196 fn codegen_mut(&mut self) -> &mut CodeGen<'static> {
197 unsafe {
198 self.inner
199 .as_mut()
200 .get_unchecked_mut()
201 .codegen
202 .as_mut()
203 .unwrap()
204 }
205 }
206 fn codegen_and_machine_mut(&mut self) -> (&mut CodeGen<'static>, &TargetMachine) {
207 (
208 unsafe {
209 self.inner
210 .as_mut()
211 .get_unchecked_mut()
212 .codegen
213 .as_mut()
214 .unwrap()
215 },
216 &self.machine,
217 )
218 }
219
220 fn codegen(&self) -> &CodeGen<'static> {
221 self.inner.as_ref().get_ref().codegen.as_ref().unwrap()
222 }
223}
224
// Marker impl: no required items at this site; the concrete capabilities
// (compile / emit / JIT) are provided by the trait impls below.
impl CodegenModule for LlvmModule {}
226
impl CodegenModuleCompile for LlvmModule {
    /// Compiles `model` into a fully optimised LLVM module.
    ///
    /// Pipeline: emit all primal functions, run a light pre-autodiff
    /// optimisation pass, generate Enzyme gradients (forward, reverse and
    /// both sensitivity variants), then run full optimisation.
    fn from_discrete_model(
        model: &DiscreteModel,
        mode: CompilerMode,
        triple: Option<Triple>,
    ) -> Result<Self> {
        // Threading is enabled only when the chosen mode yields >1 thread
        // for this state size.
        let thread_dim = mode.thread_dim(model.state().nnz());
        let threaded = thread_dim > 1;

        let mut module = Self::new(triple, model, threaded)?;

        // Primal functions. The boolean selects a variant of rhs/calc_out/
        // inputs ("full" vs not) — presumably with/without extra outputs;
        // TODO confirm against the compile_* definitions.
        let set_u0 = module.codegen_mut().compile_set_u0(model)?;
        let _calc_stop = module.codegen_mut().compile_calc_stop(model)?;
        let rhs = module.codegen_mut().compile_rhs(model, false)?;
        let rhs_full = module.codegen_mut().compile_rhs(model, true)?;
        let mass = module.codegen_mut().compile_mass(model)?;
        let calc_out = module.codegen_mut().compile_calc_out(model, false)?;
        let calc_out_full = module.codegen_mut().compile_calc_out(model, true)?;
        let _set_id = module.codegen_mut().compile_set_id(model)?;
        let _get_dims = module.codegen_mut().compile_get_dims(model)?;
        let set_inputs = module.codegen_mut().compile_inputs(model, false)?;
        let _get_inputs = module.codegen_mut().compile_inputs(model, true)?;
        let _set_constants = module.codegen_mut().compile_set_constants(model)?;
        // Collect tensor names first so the immutable borrow of `codegen()`
        // ends before the mutable compile_get_* calls below.
        let tensor_info = module
            .codegen()
            .layout
            .tensors()
            .map(|(name, is_constant)| (name.to_string(), is_constant))
            .collect::<Vec<_>>();
        for (tensor, is_constant) in tensor_info {
            if is_constant {
                module
                    .codegen_mut()
                    .compile_get_constant(model, tensor.as_str())?;
            } else {
                module
                    .codegen_mut()
                    .compile_get_tensor(model, tensor.as_str())?;
            }
        }

        // Simplify the IR before handing it to Enzyme.
        module.pre_autodiff_optimisation()?;

        // Forward-mode gradients.
        module.codegen_mut().compile_gradient(
            set_u0,
            &[
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::Const,
                CompileGradientArgType::Const,
            ],
            CompileMode::Forward,
            "set_u0_grad",
        )?;

        module.codegen_mut().compile_gradient(
            rhs,
            &[
                CompileGradientArgType::Const,
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::Const,
                CompileGradientArgType::Const,
            ],
            CompileMode::Forward,
            "rhs_grad",
        )?;

        module.codegen_mut().compile_gradient(
            calc_out,
            &[
                CompileGradientArgType::Const,
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::Const,
                CompileGradientArgType::Const,
            ],
            CompileMode::Forward,
            "calc_out_grad",
        )?;
        module.codegen_mut().compile_gradient(
            set_inputs,
            &[
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::DupNoNeed,
            ],
            CompileMode::Forward,
            "set_inputs_grad",
        )?;

        // Reverse-mode gradients.
        module.codegen_mut().compile_gradient(
            set_u0,
            &[
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::Const,
                CompileGradientArgType::Const,
            ],
            CompileMode::Reverse,
            "set_u0_rgrad",
        )?;

        module.codegen_mut().compile_gradient(
            mass,
            &[
                CompileGradientArgType::Const,
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::Const,
                CompileGradientArgType::Const,
            ],
            CompileMode::Reverse,
            "mass_rgrad",
        )?;

        module.codegen_mut().compile_gradient(
            rhs,
            &[
                CompileGradientArgType::Const,
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::Const,
                CompileGradientArgType::Const,
            ],
            CompileMode::Reverse,
            "rhs_rgrad",
        )?;
        module.codegen_mut().compile_gradient(
            calc_out,
            &[
                CompileGradientArgType::Const,
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::Const,
                CompileGradientArgType::Const,
            ],
            CompileMode::Reverse,
            "calc_out_rgrad",
        )?;

        module.codegen_mut().compile_gradient(
            set_inputs,
            &[
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::DupNoNeed,
            ],
            CompileMode::Reverse,
            "set_inputs_rgrad",
        )?;

        // Forward sensitivity gradients (on the "full" variants).
        module.codegen_mut().compile_gradient(
            rhs_full,
            &[
                CompileGradientArgType::Const,
                CompileGradientArgType::Const,
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::Const,
                CompileGradientArgType::Const,
            ],
            CompileMode::ForwardSens,
            "rhs_sgrad",
        )?;

        module.codegen_mut().compile_gradient(
            calc_out_full,
            &[
                CompileGradientArgType::Const,
                CompileGradientArgType::Const,
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::Const,
                CompileGradientArgType::Const,
            ],
            CompileMode::ForwardSens,
            "calc_out_sgrad",
        )?;
        // Reverse sensitivity gradients.
        module.codegen_mut().compile_gradient(
            calc_out_full,
            &[
                CompileGradientArgType::Const,
                CompileGradientArgType::Const,
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::Const,
                CompileGradientArgType::Const,
            ],
            CompileMode::ReverseSens,
            "calc_out_srgrad",
        )?;

        module.codegen_mut().compile_gradient(
            rhs_full,
            &[
                CompileGradientArgType::Const,
                CompileGradientArgType::Const,
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::DupNoNeed,
                CompileGradientArgType::Const,
                CompileGradientArgType::Const,
            ],
            CompileMode::ReverseSens,
            "rhs_srgrad",
        )?;

        // Full optimisation now that all gradients exist.
        module.post_autodiff_optimisation()?;
        Ok(module)
    }
}
441
442impl CodegenModuleEmit for LlvmModule {
443 fn to_object(self) -> Result<Vec<u8>> {
444 let module = self.codegen().module();
445 let buffer = self
447 .machine
448 .write_to_memory_buffer(module, FileType::Object)
449 .unwrap()
450 .as_slice()
451 .to_vec();
452 Ok(buffer)
453 }
454}
455impl CodegenModuleJit for LlvmModule {
456 fn jit(&mut self) -> Result<HashMap<String, *const u8>> {
457 let ee = self
458 .codegen()
459 .module()
460 .create_jit_execution_engine(OptimizationLevel::Aggressive)
461 .map_err(|e| anyhow!("Failed to create JIT execution engine: {:?}", e))?;
462
463 let module = self.codegen().module();
464 let mut symbols = HashMap::new();
465 for function in module.get_functions() {
466 let name = function.get_name().to_str().unwrap();
467 let address = ee.get_function_address(name);
468 if let Ok(address) = address {
469 symbols.insert(name.to_string(), address as *const u8);
470 }
471 }
472 Ok(symbols)
473 }
474}
475
/// Module-level globals shared by all generated functions.
struct Globals<'ctx> {
    // Immutable array of layout indices (`None` when the layout has none).
    indices: Option<GlobalValue<'ctx>>,
    // Mutable array of model constants; zero-initialised, apparently filled
    // by the generated `set_constants` function — confirm at its use sites.
    constants: Option<GlobalValue<'ctx>>,
    // Spin-barrier counter; present only in threaded builds.
    thread_counter: Option<GlobalValue<'ctx>>,
}
481
482impl<'ctx> Globals<'ctx> {
483 fn new(
484 layout: &DataLayout,
485 module: &Module<'ctx>,
486 int_type: IntType<'ctx>,
487 real_type: FloatType<'ctx>,
488 threaded: bool,
489 ) -> Self {
490 let thread_counter = if threaded {
491 let tc = module.add_global(
492 int_type,
493 Some(AddressSpace::default()),
494 "enzyme_const_thread_counter",
495 );
496 let tc_value = int_type.const_zero();
503 tc.set_visibility(GlobalVisibility::Hidden);
504 tc.set_initializer(&tc_value.as_basic_value_enum());
505 Some(tc)
506 } else {
507 None
508 };
509 let constants = if layout.constants().is_empty() {
510 None
511 } else {
512 let constants_array_type =
513 real_type.array_type(u32::try_from(layout.constants().len()).unwrap());
514 let constants = module.add_global(
515 constants_array_type,
516 Some(AddressSpace::default()),
517 "enzyme_const_constants",
518 );
519 constants.set_visibility(GlobalVisibility::Hidden);
520 constants.set_constant(false);
521 constants.set_initializer(&constants_array_type.const_zero());
522 Some(constants)
523 };
524 let indices = if layout.indices().is_empty() {
525 None
526 } else {
527 let indices_array_type =
528 int_type.array_type(u32::try_from(layout.indices().len()).unwrap());
529 let indices_array_values = layout
530 .indices()
531 .iter()
532 .map(|&i| int_type.const_int(i.try_into().unwrap(), false))
533 .collect::<Vec<IntValue>>();
534 let indices_value = int_type.const_array(indices_array_values.as_slice());
535 let indices = module.add_global(
536 indices_array_type,
537 Some(AddressSpace::default()),
538 "enzyme_const_indices",
539 );
540 indices.set_constant(true);
541 indices.set_visibility(GlobalVisibility::Hidden);
542 indices.set_initializer(&indices_value);
543 Some(indices)
544 };
545 Self {
546 indices,
547 thread_counter,
548 constants,
549 }
550 }
551}
552
/// How each function argument is treated when generating a gradient
/// (mirrors Enzyme's `CDIFFE_TYPE_DFT_*` activity kinds).
pub enum CompileGradientArgType {
    /// Constant w.r.t. differentiation (`DFT_CONSTANT`).
    Const,
    /// Duplicated: a shadow value is passed alongside the primal
    /// (`DFT_DUP_ARG`).
    Dup,
    /// Duplicated, but the primal value itself is not needed
    /// (`DFT_DUP_NONEED`).
    DupNoNeed,
}
558
/// Which Enzyme differentiation pipeline to generate.
pub enum CompileMode {
    /// Forward-mode derivative.
    Forward,
    /// Forward-mode sensitivity variant.
    ForwardSens,
    /// Reverse-mode (combined primal-and-gradient) derivative.
    Reverse,
    /// Reverse-mode sensitivity variant.
    ReverseSens,
}
565
/// Per-model LLVM code generator: owns the module, the IR builder and all
/// bookkeeping state used while emitting functions.
pub struct CodeGen<'ctx> {
    context: &'ctx inkwell::context::Context,
    module: Module<'ctx>,
    builder: Builder<'ctx>,
    // Named pointers (tensors/parameters) currently in scope.
    variables: HashMap<String, PointerValue<'ctx>>,
    // Generated functions looked up by name (e.g. "barrier").
    functions: HashMap<String, FunctionValue<'ctx>>,
    // Function currently being emitted, if any.
    fn_value_opt: Option<FunctionValue<'ctx>>,
    // Output pointer of the tensor currently being compiled — presumably
    // set by jit_compile_tensor; confirm further down the file.
    tensor_ptr_opt: Option<PointerValue<'ctx>>,
    real_type: FloatType<'ctx>,
    real_ptr_type: PointerType<'ctx>,
    // Textual name of the real type (e.g. "f64").
    real_type_str: String,
    int_type: IntType<'ctx>,
    int_ptr_type: PointerType<'ctx>,
    layout: DataLayout,
    globals: Globals<'ctx>,
    threaded: bool,
    _ee: Option<ExecutionEngine<'ctx>>,
}
584
/// Enzyme custom forward-mode call handler for `barrier` calls
/// (registered via `EnzymeRegisterCallHandler`).
///
/// Emits nothing and returns 1 — presumably signalling Enzyme to keep the
/// primal call unchanged; TODO confirm the return-value contract against
/// Enzyme's CallHandler API.
unsafe extern "C" fn fwd_handler(
    _builder: LLVMBuilderRef,
    _call_instruction: LLVMValueRef,
    _gutils: *mut GradientUtils,
    _dcall: *mut LLVMValueRef,
    _normal_return: *mut LLVMValueRef,
    _shadow_return: *mut LLVMValueRef,
) -> u8 {
    1
}
595
/// Enzyme custom reverse-mode call handler for `barrier` calls.
///
/// In the reverse (gradient) pass this emits a call to `barrier_grad` with
/// the original call's three arguments (barrier number, total barriers,
/// thread count), each mapped into the gradient function via
/// `EnzymeGradientUtilsNewFromOriginal`.
unsafe extern "C" fn rev_handler(
    builder: LLVMBuilderRef,
    call_instruction: LLVMValueRef,
    gutils: *mut DiffeGradientUtils,
    _tape: LLVMValueRef,
) {
    // Walk up from the call instruction to the containing module so
    // `barrier_grad` can be looked up by name.
    let call_block = LLVMGetInstructionParent(call_instruction);
    let call_function = LLVMGetBasicBlockParent(call_block);
    let module = LLVMGetGlobalParent(call_function);
    let name_c_str = CString::new("barrier_grad").unwrap();
    let barrier_func = LLVMGetNamedFunction(module, name_c_str.as_ptr());
    let barrier_func_type = LLVMGlobalGetValueType(barrier_func);
    // Original (primal) arguments of the barrier call.
    let barrier_num = LLVMGetArgOperand(call_instruction, 0);
    let total_barriers = LLVMGetArgOperand(call_instruction, 1);
    let thread_count = LLVMGetArgOperand(call_instruction, 2);
    // Translate them into the values valid inside the generated gradient.
    let barrier_num = EnzymeGradientUtilsNewFromOriginal(gutils as *mut GradientUtils, barrier_num);
    let total_barriers =
        EnzymeGradientUtilsNewFromOriginal(gutils as *mut GradientUtils, total_barriers);
    let thread_count =
        EnzymeGradientUtilsNewFromOriginal(gutils as *mut GradientUtils, thread_count);
    let mut args = [barrier_num, total_barriers, thread_count];
    let name_c_str = CString::new("").unwrap();
    LLVMBuildCall2(
        builder,
        barrier_func_type,
        barrier_func,
        args.as_mut_ptr(),
        args.len() as u32,
        name_c_str.as_ptr(),
    );
}
627
/// A value that `compile_print_value` can emit a debug `printf` call for.
#[allow(dead_code)]
enum PrintValue<'ctx> {
    Real(FloatValue<'ctx>),
    Int(IntValue<'ctx>),
}
633
634impl<'ctx> CodeGen<'ctx> {
    /// Creates a code generator for `model`: builds the module, declares the
    /// shared globals and, when `threaded`, pre-compiles the barrier
    /// functions (`barrier_init`, `barrier`, `barrier_grad`).
    pub fn new(
        model: &DiscreteModel,
        context: &'ctx inkwell::context::Context,
        real_type: FloatType<'ctx>,
        int_type: IntType<'ctx>,
        real_type_str: &str,
        threaded: bool,
    ) -> Result<Self> {
        let builder = context.create_builder();
        let layout = DataLayout::new(model);
        let module = context.create_module(model.name());
        let globals = Globals::new(&layout, &module, int_type, real_type, threaded);
        let real_ptr_type = Self::pointer_type(context, real_type.into());
        let int_ptr_type = Self::pointer_type(context, int_type.into());
        let mut ret = Self {
            context,
            module,
            builder,
            real_type,
            real_ptr_type,
            real_type_str: real_type_str.to_owned(),
            variables: HashMap::new(),
            functions: HashMap::new(),
            fn_value_opt: None,
            tensor_ptr_opt: None,
            layout,
            int_type,
            int_ptr_type,
            globals,
            threaded,
            _ee: None,
        };
        if threaded {
            // The barrier functions must exist before any generated code
            // tries to call them.
            ret.compile_barrier_init()?;
            ret.compile_barrier()?;
            ret.compile_barrier_grad()?;
        }
        Ok(ret)
    }
676
    /// Debug helper: emits `printf("<name>: %f\n", value)` (or `%d` for
    /// ints) at the current insertion point. Declares `printf` and a
    /// per-name format-string global on first use.
    ///
    /// NOTE(review): `printf` is declared here as returning void rather than
    /// C's `int`; harmless while the result is ignored — confirm.
    #[allow(dead_code)]
    fn compile_print_value(
        &mut self,
        name: &str,
        value: PrintValue<'ctx>,
    ) -> Result<CallSiteValue> {
        let void_type = self.context.void_type();
        // Variadic: printf(fmt, ...).
        let printf_type = void_type.fn_type(&[self.int_ptr_type.into()], true);
        let printf = match self.module.get_function("printf") {
            Some(f) => f,
            None => self
                .module
                .add_function("printf", printf_type, Some(Linkage::External)),
        };
        let (format_str, format_str_name) = match value {
            PrintValue::Real(_) => (format!("{name}: %f\n"), format!("real_format_{name}")),
            PrintValue::Int(_) => (format!("{name}: %d\n"), format!("int_format_{name}")),
        };
        let format_str = CString::new(format_str).unwrap();
        // Reuse the format-string global if this name was printed before.
        let format_str_global = match self.module.get_global(format_str_name.as_str()) {
            Some(g) => g,
            None => {
                let format_str = self.context.const_string(format_str.as_bytes(), true);
                let fmt_str =
                    self.module
                        .add_global(format_str.get_type(), None, format_str_name.as_str());
                fmt_str.set_initializer(&format_str);
                fmt_str.set_visibility(GlobalVisibility::Hidden);
                fmt_str
            }
        };
        let format_str_ptr = self.builder.build_pointer_cast(
            format_str_global.as_pointer_value(),
            self.int_ptr_type,
            "format_str_ptr",
        )?;
        let value: BasicMetadataValueEnum = match value {
            PrintValue::Real(v) => v.into(),
            PrintValue::Int(v) => v.into(),
        };
        self.builder
            .build_call(printf, &[format_str_ptr.into(), value], "printf_call")
            .map_err(|e| anyhow!("Error building call to printf: {}", e))
    }
726
    /// Emits `set_constants(thread_id, thread_dim)`: evaluates every
    /// constant tensor definition of the model into its storage, with a
    /// barrier after each tensor in threaded builds.
    fn compile_set_constants(&mut self, model: &DiscreteModel) -> Result<FunctionValue<'ctx>> {
        self.clear();
        let void_type = self.context.void_type();
        let fn_type = void_type.fn_type(&[self.int_type.into(), self.int_type.into()], false);
        let fn_arg_names = &["thread_id", "thread_dim"];
        let function = self.module.add_function("set_constants", fn_type, None);

        let basic_block = self.context.append_basic_block(function, "entry");
        self.fn_value_opt = Some(function);
        self.builder.position_at_end(basic_block);

        // Spill args to allocas so later code can address them by name.
        for (i, arg) in function.get_param_iter().enumerate() {
            let name = fn_arg_names[i];
            let alloca = self.function_arg_alloca(name, arg);
            self.insert_param(name, alloca);
        }

        self.insert_indices();
        self.insert_constants(model);

        // One barrier per constant tensor; barrier numbers are sequential.
        let mut nbarriers = 0;
        let total_barriers = (model.constant_defns().len()) as u64;
        #[allow(clippy::explicit_counter_loop)]
        for a in model.constant_defns() {
            self.jit_compile_tensor(a, Some(*self.get_var(a)))?;
            self.jit_compile_call_barrier(nbarriers, total_barriers);
            nbarriers += 1;
        }

        self.builder.build_return(None)?;

        if function.verify(true) {
            Ok(function)
        } else {
            // Dump the invalid IR to aid debugging, then remove it.
            function.print_to_stderr();
            unsafe {
                function.delete();
            }
            Err(anyhow!("Invalid generated function."))
        }
    }
768
769 fn compile_barrier_init(&mut self) -> Result<FunctionValue<'ctx>> {
770 self.clear();
771 let void_type = self.context.void_type();
772 let fn_type = void_type.fn_type(&[], false);
773 let function = self.module.add_function("barrier_init", fn_type, None);
774
775 let entry_block = self.context.append_basic_block(function, "entry");
776
777 self.fn_value_opt = Some(function);
778 self.builder.position_at_end(entry_block);
779
780 let thread_counter = self.globals.thread_counter.unwrap().as_pointer_value();
781 self.builder
782 .build_store(thread_counter, self.int_type.const_zero())?;
783
784 self.builder.build_return(None)?;
785
786 if function.verify(true) {
787 self.functions.insert("barrier_init".to_owned(), function);
788 Ok(function)
789 } else {
790 function.print_to_stderr();
791 unsafe {
792 function.delete();
793 }
794 Err(anyhow!("Invalid generated function."))
795 }
796 }
797
    /// Emits `barrier(barrier_num, total_barriers, thread_count)`: a
    /// spin-wait synchronisation point for generated threaded code.
    ///
    /// Each thread atomically increments a shared counter and spins until
    /// the counter reaches `barrier_num * thread_count`. The final barrier
    /// (`barrier_num == total_barriers`) is skipped entirely. Marked
    /// `noinline` so Enzyme's custom call handlers can intercept calls; the
    /// attribute is stripped again after autodiff.
    fn compile_barrier(&mut self) -> Result<FunctionValue<'ctx>> {
        self.clear();
        let void_type = self.context.void_type();
        let fn_type = void_type.fn_type(
            &[
                self.int_type.into(),
                self.int_type.into(),
                self.int_type.into(),
            ],
            false,
        );
        let function = self.module.add_function("barrier", fn_type, None);
        let nolinline_kind_id = Attribute::get_named_enum_kind_id("noinline");
        let noinline = self.context.create_enum_attribute(nolinline_kind_id, 0);
        function.add_attribute(AttributeLoc::Function, noinline);

        let entry_block = self.context.append_basic_block(function, "entry");
        let increment_block = self.context.append_basic_block(function, "increment");
        let wait_loop_block = self.context.append_basic_block(function, "wait_loop");
        let barrier_done_block = self.context.append_basic_block(function, "barrier_done");

        self.fn_value_opt = Some(function);
        self.builder.position_at_end(entry_block);

        let thread_counter = self.globals.thread_counter.unwrap().as_pointer_value();
        let barrier_num = function.get_nth_param(0).unwrap().into_int_value();
        let total_barriers = function.get_nth_param(1).unwrap().into_int_value();
        let thread_count = function.get_nth_param(2).unwrap().into_int_value();

        // Skip the final barrier: nothing runs after it, so no sync needed.
        let nbarrier_equals_total_barriers = self
            .builder
            .build_int_compare(
                IntPredicate::EQ,
                barrier_num,
                total_barriers,
                "nbarrier_equals_total_barriers",
            )
            .unwrap();
        self.builder.build_conditional_branch(
            nbarrier_equals_total_barriers,
            barrier_done_block,
            increment_block,
        )?;
        self.builder.position_at_end(increment_block);

        let barrier_num_times_thread_count = self
            .builder
            .build_int_mul(barrier_num, thread_count, "barrier_num_times_thread_count")
            .unwrap();

        // NOTE(review): the counter ops below hard-code i32; assumes
        // `self.int_type` is i32 (as set up in LlvmModule::new) — confirm.
        let i32_type = self.context.i32_type();
        let one = i32_type.const_int(1, false);
        self.builder.build_atomicrmw(
            AtomicRMWBinOp::Add,
            thread_counter,
            one,
            AtomicOrdering::Monotonic,
        )?;

        self.builder.build_unconditional_branch(wait_loop_block)?;
        self.builder.position_at_end(wait_loop_block);

        // Spin: atomically reload the counter each iteration.
        let current_value = self
            .builder
            .build_load(i32_type, thread_counter, "current_value")?
            .into_int_value();

        current_value
            .as_instruction_value()
            .unwrap()
            .set_atomic_ordering(AtomicOrdering::Monotonic)
            .map_err(|e| anyhow!("Error setting atomic ordering: {:?}", e))?;

        // Done once every thread has passed this barrier number.
        let all_threads_done = self.builder.build_int_compare(
            IntPredicate::UGE,
            current_value,
            barrier_num_times_thread_count,
            "all_threads_done",
        )?;

        self.builder.build_conditional_branch(
            all_threads_done,
            barrier_done_block,
            wait_loop_block,
        )?;
        self.builder.position_at_end(barrier_done_block);

        self.builder.build_return(None)?;

        if function.verify(true) {
            self.functions.insert("barrier".to_owned(), function);
            Ok(function)
        } else {
            function.print_to_stderr();
            unsafe {
                function.delete();
            }
            Err(anyhow!("Invalid generated function."))
        }
    }
901
    /// Emits `barrier_grad(barrier_num, total_barriers, thread_count)`: the
    /// barrier used during the reverse (gradient) sweep, called via the
    /// Enzyme custom handler `rev_handler`.
    ///
    /// Threads spin until the shared counter reaches
    /// `(2 * total_barriers - barrier_num) * thread_count` — presumably
    /// because the reverse sweep revisits the barriers in reverse order
    /// after the forward sweep's `total_barriers` increments; confirm
    /// against the runtime's barrier numbering.
    fn compile_barrier_grad(&mut self) -> Result<FunctionValue<'ctx>> {
        self.clear();
        let void_type = self.context.void_type();
        let fn_type = void_type.fn_type(
            &[
                self.int_type.into(),
                self.int_type.into(),
                self.int_type.into(),
            ],
            false,
        );
        let function = self.module.add_function("barrier_grad", fn_type, None);

        let entry_block = self.context.append_basic_block(function, "entry");
        let wait_loop_block = self.context.append_basic_block(function, "wait_loop");
        let barrier_done_block = self.context.append_basic_block(function, "barrier_done");

        self.fn_value_opt = Some(function);
        self.builder.position_at_end(entry_block);

        let thread_counter = self.globals.thread_counter.unwrap().as_pointer_value();
        let barrier_num = function.get_nth_param(0).unwrap().into_int_value();
        let total_barriers = function.get_nth_param(1).unwrap().into_int_value();
        let thread_count = function.get_nth_param(2).unwrap().into_int_value();

        // target = (2 * total_barriers - barrier_num) * thread_count
        let twice_total_barriers = self
            .builder
            .build_int_mul(
                total_barriers,
                self.int_type.const_int(2, false),
                "twice_total_barriers",
            )
            .unwrap();
        let twice_total_barriers_minus_barrier_num = self
            .builder
            .build_int_sub(
                twice_total_barriers,
                barrier_num,
                "twice_total_barriers_minus_barrier_num",
            )
            .unwrap();
        let twice_total_barriers_minus_barrier_num_times_thread_count = self
            .builder
            .build_int_mul(
                twice_total_barriers_minus_barrier_num,
                thread_count,
                "twice_total_barriers_minus_barrier_num_times_thread_count",
            )
            .unwrap();

        // NOTE(review): hard-codes i32 for the counter ops; assumes
        // `self.int_type` is i32 — confirm (matches compile_barrier).
        let i32_type = self.context.i32_type();
        let one = i32_type.const_int(1, false);
        self.builder.build_atomicrmw(
            AtomicRMWBinOp::Add,
            thread_counter,
            one,
            AtomicOrdering::Monotonic,
        )?;

        self.builder.build_unconditional_branch(wait_loop_block)?;
        self.builder.position_at_end(wait_loop_block);

        // Spin: atomically reload the counter each iteration.
        let current_value = self
            .builder
            .build_load(i32_type, thread_counter, "current_value")?
            .into_int_value();
        current_value
            .as_instruction_value()
            .unwrap()
            .set_atomic_ordering(AtomicOrdering::Monotonic)
            .map_err(|e| anyhow!("Error setting atomic ordering: {:?}", e))?;

        let all_threads_done = self.builder.build_int_compare(
            IntPredicate::UGE,
            current_value,
            twice_total_barriers_minus_barrier_num_times_thread_count,
            "all_threads_done",
        )?;

        self.builder.build_conditional_branch(
            all_threads_done,
            barrier_done_block,
            wait_loop_block,
        )?;
        self.builder.position_at_end(barrier_done_block);

        self.builder.build_return(None)?;

        if function.verify(true) {
            self.functions.insert("barrier_grad".to_owned(), function);
            Ok(function)
        } else {
            function.print_to_stderr();
            unsafe {
                function.delete();
            }
            Err(anyhow!("Invalid generated function."))
        }
    }
1003
1004 fn jit_compile_call_barrier(&mut self, nbarrier: u64, total_barriers: u64) {
1005 if !self.threaded {
1006 return;
1007 }
1008 let thread_dim = self.get_param("thread_dim");
1009 let thread_dim = self
1010 .builder
1011 .build_load(self.int_type, *thread_dim, "thread_dim")
1012 .unwrap()
1013 .into_int_value();
1014 let nbarrier = self.int_type.const_int(nbarrier + 1, false);
1015 let total_barriers = self.int_type.const_int(total_barriers, false);
1016 let barrier = self.get_function("barrier").unwrap();
1017 self.builder
1018 .build_call(
1019 barrier,
1020 &[
1021 BasicMetadataValueEnum::IntValue(nbarrier),
1022 BasicMetadataValueEnum::IntValue(total_barriers),
1023 BasicMetadataValueEnum::IntValue(thread_dim),
1024 ],
1025 "barrier",
1026 )
1027 .unwrap();
1028 }
1029
    /// Compute this thread's `[start, end)` slice of a `size`-element workload.
    ///
    /// Emits `start = thread_id * size / thread_dim` and
    /// `end = (thread_id + 1) * size / thread_dim`, then opens a fresh
    /// "threading_block" and positions the builder there. Returns
    /// `(start, end, test_done, next_block)`, where `test_done` is the block
    /// that was current on entry; its terminator is filled in later by
    /// `jit_end_threading`.
    fn jit_threading_limits(
        &mut self,
        size: IntValue<'ctx>,
    ) -> Result<(
        IntValue<'ctx>,
        IntValue<'ctx>,
        BasicBlock<'ctx>,
        BasicBlock<'ctx>,
    )> {
        let one = self.int_type.const_int(1, false);
        // Load the caller-supplied thread id and thread count parameters.
        let thread_id = self.get_param("thread_id");
        let thread_id = self
            .builder
            .build_load(self.int_type, *thread_id, "thread_id")
            .unwrap()
            .into_int_value();
        let thread_dim = self.get_param("thread_dim");
        let thread_dim = self
            .builder
            .build_load(self.int_type, *thread_dim, "thread_dim")
            .unwrap()
            .into_int_value();

        // start = thread_id * size / thread_dim (unsigned arithmetic).
        let i_times_size = self
            .builder
            .build_int_mul(thread_id, size, "i_times_size")?;
        let start = self
            .builder
            .build_int_unsigned_div(i_times_size, thread_dim, "start")?;

        // end = (thread_id + 1) * size / thread_dim.
        let i_plus_one = self.builder.build_int_add(thread_id, one, "i_plus_one")?;
        let i_plus_one_times_size =
            self.builder
                .build_int_mul(i_plus_one, size, "i_plus_one_times_size")?;
        let end = self
            .builder
            .build_int_unsigned_div(i_plus_one_times_size, thread_dim, "end")?;

        // Leave the current block unterminated; jit_end_threading will add
        // the branch that skips the loop when the range is empty.
        let test_done = self.builder.get_insert_block().unwrap();
        let next_block = self
            .context
            .append_basic_block(self.fn_value_opt.unwrap(), "threading_block");
        self.builder.position_at_end(next_block);

        Ok((start, end, test_done, next_block))
    }
1078
1079 fn jit_end_threading(
1080 &mut self,
1081 start: IntValue<'ctx>,
1082 end: IntValue<'ctx>,
1083 test_done: BasicBlock<'ctx>,
1084 next: BasicBlock<'ctx>,
1085 ) -> Result<()> {
1086 let exit = self
1087 .context
1088 .append_basic_block(self.fn_value_opt.unwrap(), "exit");
1089 self.builder.build_unconditional_branch(exit)?;
1090 self.builder.position_at_end(test_done);
1091 let done = self
1093 .builder
1094 .build_int_compare(IntPredicate::EQ, start, end, "done")?;
1095 self.builder.build_conditional_branch(done, exit, next)?;
1096 self.builder.position_at_end(exit);
1097 Ok(())
1098 }
1099
    /// Serialise the compiled LLVM module as bitcode to `path`.
    pub fn write_bitcode_to_path(&self, path: &std::path::Path) {
        self.module.write_bitcode_to_path(path);
    }
1103
1104 fn insert_constants(&mut self, model: &DiscreteModel) {
1105 if let Some(constants) = self.globals.constants.as_ref() {
1106 self.insert_param("constants", constants.as_pointer_value());
1107 for tensor in model.constant_defns() {
1108 self.insert_tensor(tensor, true);
1109 }
1110 }
1111 }
1112
    /// Register pointers for every tensor held in the mutable "data" buffer
    /// (inputs plus input-, time- and state-dependent definitions), after
    /// first registering any constants.
    fn insert_data(&mut self, model: &DiscreteModel) {
        self.insert_constants(model);

        // All of the following live in the "data" buffer (is_constant = false).
        for tensor in model.inputs() {
            self.insert_tensor(tensor, false);
        }
        for tensor in model.input_dep_defns() {
            self.insert_tensor(tensor, false);
        }
        for tensor in model.time_dep_defns() {
            self.insert_tensor(tensor, false);
        }
        for tensor in model.state_dep_defns() {
            self.insert_tensor(tensor, false);
        }
    }
1129
    /// Return the generic pointer type for the default address space.
    /// `_ty` is unused: with LLVM opaque pointers the pointee type is no
    /// longer part of the pointer type.
    fn pointer_type(context: &'ctx Context, _ty: BasicTypeEnum<'ctx>) -> PointerType<'ctx> {
        context.ptr_type(AddressSpace::default())
    }
1133
    /// Return the pointer type used for function pointers; `_ty` is unused
    /// for the same opaque-pointer reason as `pointer_type`.
    fn fn_pointer_type(context: &'ctx Context, _ty: FunctionType<'ctx>) -> PointerType<'ctx> {
        context.ptr_type(AddressSpace::default())
    }
1137
    /// Expose the global "indices" array (read by the sparse code paths via
    /// `get_param("indices")`) as a plain i32 pointer in the variables map,
    /// when the module defines one.
    fn insert_indices(&mut self) {
        if let Some(indices) = self.globals.indices.as_ref() {
            let i32_type = self.context.i32_type();
            let zero = i32_type.const_int(0, false);
            // SAFETY: a constant GEP at index 0 points at the first element
            // of the global and stays within its allocation.
            let ptr = unsafe {
                indices
                    .as_pointer_value()
                    .const_in_bounds_gep(i32_type, &[zero])
            };
            self.variables.insert("indices".to_owned(), ptr);
        }
    }
1150
    /// Record a named parameter pointer in the variables map, replacing any
    /// previous binding of the same name.
    fn insert_param(&mut self, name: &str, value: PointerValue<'ctx>) {
        self.variables.insert(name.to_owned(), value);
    }
1154
    /// Thin wrapper over `Builder::build_gep` that converts the builder error
    /// into an `anyhow` error.
    ///
    /// `build_gep` is unsafe because LLVM does not verify that the computed
    /// address is in bounds; callers are responsible for valid indexes.
    fn build_gep<T: BasicType<'ctx>>(
        &self,
        ty: T,
        ptr: PointerValue<'ctx>,
        ordered_indexes: &[IntValue<'ctx>],
        name: &str,
    ) -> Result<PointerValue<'ctx>> {
        // SAFETY: callers supply indexes valid for the pointed-to object.
        unsafe {
            self.builder
                .build_gep(ty, ptr, ordered_indexes, name)
                .map_err(|e| e.into())
        }
    }
1168
1169 fn build_load<T: BasicType<'ctx>>(
1170 &self,
1171 ty: T,
1172 ptr: PointerValue<'ctx>,
1173 name: &str,
1174 ) -> Result<BasicValueEnum<'ctx>> {
1175 self.builder.build_load(ty, ptr, name).map_err(|e| e.into())
1176 }
1177
    /// Build an in-bounds GEP computing `&ptr[index]` with element type `ty`.
    /// Panics on builder failure; callers must supply an in-bounds index.
    fn get_ptr_to_index<T: BasicType<'ctx>>(
        builder: &Builder<'ctx>,
        ty: T,
        ptr: &PointerValue<'ctx>,
        index: IntValue<'ctx>,
        name: &str,
    ) -> PointerValue<'ctx> {
        // SAFETY: the in-bounds GEP is sound as long as `index` stays within
        // the allocation `ptr` points into, which callers guarantee.
        unsafe {
            builder
                .build_in_bounds_gep(ty, *ptr, &[index], name)
                .unwrap()
        }
    }
1191
1192 fn insert_state(&mut self, u: &Tensor) {
1193 let mut data_index = 0;
1194 for blk in u.elmts() {
1195 if let Some(name) = blk.name() {
1196 let ptr = self.variables.get("u").unwrap();
1197 let i = self
1198 .context
1199 .i32_type()
1200 .const_int(data_index.try_into().unwrap(), false);
1201 let alloca = Self::get_ptr_to_index(
1202 &self.create_entry_block_builder(),
1203 self.real_type,
1204 ptr,
1205 i,
1206 blk.name().unwrap(),
1207 );
1208 self.variables.insert(name.to_owned(), alloca);
1209 }
1210 data_index += blk.nnz();
1211 }
1212 }
1213 fn insert_dot_state(&mut self, dudt: &Tensor) {
1214 let mut data_index = 0;
1215 for blk in dudt.elmts() {
1216 if let Some(name) = blk.name() {
1217 let ptr = self.variables.get("dudt").unwrap();
1218 let i = self
1219 .context
1220 .i32_type()
1221 .const_int(data_index.try_into().unwrap(), false);
1222 let alloca = Self::get_ptr_to_index(
1223 &self.create_entry_block_builder(),
1224 self.real_type,
1225 ptr,
1226 i,
1227 blk.name().unwrap(),
1228 );
1229 self.variables.insert(name.to_owned(), alloca);
1230 }
1231 data_index += blk.nnz();
1232 }
1233 }
1234 fn insert_tensor(&mut self, tensor: &Tensor, is_constant: bool) {
1235 let var_name = if is_constant { "constants" } else { "data" };
1236 let ptr = *self.variables.get(var_name).unwrap();
1237 let mut data_index = self.layout.get_data_index(tensor.name()).unwrap();
1238 let i = self
1239 .context
1240 .i32_type()
1241 .const_int(data_index.try_into().unwrap(), false);
1242 let alloca = Self::get_ptr_to_index(
1243 &self.create_entry_block_builder(),
1244 self.real_type,
1245 &ptr,
1246 i,
1247 tensor.name(),
1248 );
1249 self.variables.insert(tensor.name().to_owned(), alloca);
1250
1251 for blk in tensor.elmts() {
1253 if let Some(name) = blk.name() {
1254 let i = self
1255 .context
1256 .i32_type()
1257 .const_int(data_index.try_into().unwrap(), false);
1258 let alloca = Self::get_ptr_to_index(
1259 &self.create_entry_block_builder(),
1260 self.real_type,
1261 &ptr,
1262 i,
1263 name,
1264 );
1265 self.variables.insert(name.to_owned(), alloca);
1266 }
1267 data_index += blk.nnz();
1269 }
1270 }
    /// Look up a previously registered parameter pointer by name.
    /// Panics if the name was never inserted.
    fn get_param(&self, name: &str) -> &PointerValue<'ctx> {
        self.variables.get(name).unwrap()
    }
1274
    /// Look up the pointer registered for `tensor` (by its name).
    /// Panics if the tensor was never inserted.
    fn get_var(&self, tensor: &Tensor) -> &PointerValue<'ctx> {
        self.variables.get(tensor.name()).unwrap()
    }
1278
    /// Look up — and lazily create on first use — a callable math function by
    /// name. Known LLVM intrinsics are declared directly; the remaining names
    /// (sigmoid, arcsinh/arccosh, heaviside, tanh/sinh/cosh) are synthesised
    /// from exp/log/sqrt and cached in `self.functions`. Returns `None` for
    /// unknown names.
    fn get_function(&mut self, name: &str) -> Option<FunctionValue<'ctx>> {
        match self.functions.get(name) {
            Some(&func) => Some(func),
            None => {
                let function = match name {
                    // Names that map straight onto overloaded LLVM intrinsics
                    // (llvm.<name>.<fN>); min/max/abs use different LLVM
                    // spellings. Note the early return: intrinsic
                    // declarations are not inserted into `self.functions`.
                    "sin" | "cos" | "tan" | "exp" | "log" | "log10" | "sqrt" | "abs"
                    | "copysign" | "pow" | "min" | "max" => {
                        // arg_len counts overload types for declaration
                        // lookup, not call arguments.
                        let arg_len = 1;
                        let intrinsic_name = match name {
                            "min" => "minnum",
                            "max" => "maxnum",
                            "abs" => "fabs",
                            _ => name,
                        };
                        let llvm_name = format!("llvm.{}.{}", intrinsic_name, self.real_type_str);
                        let intrinsic = Intrinsic::find(&llvm_name).unwrap();
                        let ret_type = self.real_type;

                        let args_types = std::iter::repeat_n(ret_type, arg_len)
                            .map(|f| f.into())
                            .collect::<Vec<BasicTypeEnum>>();
                        return intrinsic.get_declaration(&self.module, args_types.as_slice());
                    }
                    // sigmoid(x) = 1 / (1 + exp(-x)), built as a small
                    // standalone function in the module.
                    "sigmoid" => {
                        let arg_len = 1;
                        let ret_type = self.real_type;

                        let args_types = std::iter::repeat_n(ret_type, arg_len)
                            .map(|f| f.into())
                            .collect::<Vec<BasicMetadataTypeEnum>>();

                        let fn_type = ret_type.fn_type(args_types.as_slice(), false);
                        let fn_val = self.module.add_function(name, fn_type, None);

                        for arg in fn_val.get_param_iter() {
                            arg.into_float_value().set_name("x");
                        }

                        // Remember where we were so we can resume emitting
                        // the caller's code afterwards.
                        let current_block = self.builder.get_insert_block().unwrap();
                        let basic_block = self.context.append_basic_block(fn_val, "entry");
                        self.builder.position_at_end(basic_block);
                        let x = fn_val.get_nth_param(0)?.into_float_value();
                        let one = self.real_type.const_float(1.0);
                        let negx = self.builder.build_float_neg(x, name).ok()?;
                        let exp = self.get_function("exp").unwrap();
                        let exp_negx = self
                            .builder
                            .build_call(exp, &[BasicMetadataValueEnum::FloatValue(negx)], name)
                            .ok()?;
                        let one_plus_exp_negx = self
                            .builder
                            .build_float_add(
                                exp_negx
                                    .try_as_basic_value()
                                    .left()
                                    .unwrap()
                                    .into_float_value(),
                                one,
                                name,
                            )
                            .ok()?;
                        let sigmoid = self
                            .builder
                            .build_float_div(one, one_plus_exp_negx, name)
                            .ok()?;
                        self.builder.build_return(Some(&sigmoid)).ok();
                        self.builder.position_at_end(current_block);
                        Some(fn_val)
                    }
                    // arcsinh(x) = ln(x + sqrt(x^2 + 1));
                    // arccosh(x) = ln(x + sqrt(x^2 - 1)).
                    // The +/-1 is selected via the `one` constant below.
                    "arcsinh" | "arccosh" => {
                        let arg_len = 1;
                        let ret_type = self.real_type;

                        let args_types = std::iter::repeat_n(ret_type, arg_len)
                            .map(|f| f.into())
                            .collect::<Vec<BasicMetadataTypeEnum>>();

                        let fn_type = ret_type.fn_type(args_types.as_slice(), false);
                        let fn_val = self.module.add_function(name, fn_type, None);

                        for arg in fn_val.get_param_iter() {
                            arg.into_float_value().set_name("x");
                        }

                        let current_block = self.builder.get_insert_block().unwrap();
                        let basic_block = self.context.append_basic_block(fn_val, "entry");
                        self.builder.position_at_end(basic_block);
                        let x = fn_val.get_nth_param(0)?.into_float_value();
                        let one = match name {
                            "arccosh" => self.real_type.const_float(-1.0),
                            "arcsinh" => self.real_type.const_float(1.0),
                            _ => panic!("unknown function"),
                        };
                        let x_squared = self.builder.build_float_mul(x, x, name).ok()?;
                        let one_plus_x_squared =
                            self.builder.build_float_add(x_squared, one, name).ok()?;
                        let sqrt = self.get_function("sqrt").unwrap();
                        let sqrt_one_plus_x_squared = self
                            .builder
                            .build_call(
                                sqrt,
                                &[BasicMetadataValueEnum::FloatValue(one_plus_x_squared)],
                                name,
                            )
                            .unwrap()
                            .try_as_basic_value()
                            .left()
                            .unwrap()
                            .into_float_value();
                        let x_plus_sqrt_one_plus_x_squared = self
                            .builder
                            .build_float_add(x, sqrt_one_plus_x_squared, name)
                            .ok()?;
                        let ln = self.get_function("log").unwrap();
                        let result = self
                            .builder
                            .build_call(
                                ln,
                                &[BasicMetadataValueEnum::FloatValue(
                                    x_plus_sqrt_one_plus_x_squared,
                                )],
                                name,
                            )
                            .unwrap()
                            .try_as_basic_value()
                            .left()
                            .unwrap()
                            .into_float_value();
                        self.builder.build_return(Some(&result)).ok();
                        self.builder.position_at_end(current_block);
                        Some(fn_val)
                    }
                    // heaviside(x) = 1.0 if x >= 0 else 0.0 (via select).
                    "heaviside" => {
                        let arg_len = 1;
                        let ret_type = self.real_type;

                        let args_types = std::iter::repeat_n(ret_type, arg_len)
                            .map(|f| f.into())
                            .collect::<Vec<BasicMetadataTypeEnum>>();

                        let fn_type = ret_type.fn_type(args_types.as_slice(), false);
                        let fn_val = self.module.add_function(name, fn_type, None);

                        for arg in fn_val.get_param_iter() {
                            arg.into_float_value().set_name("x");
                        }

                        let current_block = self.builder.get_insert_block().unwrap();
                        let basic_block = self.context.append_basic_block(fn_val, "entry");
                        self.builder.position_at_end(basic_block);
                        let x = fn_val.get_nth_param(0)?.into_float_value();
                        let zero = self.real_type.const_float(0.0);
                        let one = self.real_type.const_float(1.0);
                        let result = self
                            .builder
                            .build_select(
                                self.builder
                                    .build_float_compare(FloatPredicate::OGE, x, zero, "x >= 0")
                                    .unwrap(),
                                one,
                                zero,
                                name,
                            )
                            .ok()?;
                        self.builder.build_return(Some(&result)).ok();
                        self.builder.position_at_end(current_block);
                        Some(fn_val)
                    }
                    // tanh/sinh/cosh built from exp(x) and exp(-x):
                    //   tanh = (e^x - e^-x) / (e^x + e^-x)
                    //   sinh = (e^x - e^-x) / 2
                    //   cosh = (e^x + e^-x) / 2
                    "tanh" | "sinh" | "cosh" => {
                        let arg_len = 1;
                        let ret_type = self.real_type;

                        let args_types = std::iter::repeat_n(ret_type, arg_len)
                            .map(|f| f.into())
                            .collect::<Vec<BasicMetadataTypeEnum>>();

                        let fn_type = ret_type.fn_type(args_types.as_slice(), false);
                        let fn_val = self.module.add_function(name, fn_type, None);

                        for arg in fn_val.get_param_iter() {
                            arg.into_float_value().set_name("x");
                        }

                        let current_block = self.builder.get_insert_block().unwrap();
                        let basic_block = self.context.append_basic_block(fn_val, "entry");
                        self.builder.position_at_end(basic_block);
                        let x = fn_val.get_nth_param(0)?.into_float_value();
                        let negx = self.builder.build_float_neg(x, name).ok()?;
                        let exp = self.get_function("exp").unwrap();
                        let exp_negx = self
                            .builder
                            .build_call(exp, &[BasicMetadataValueEnum::FloatValue(negx)], name)
                            .ok()?;
                        let expx = self
                            .builder
                            .build_call(exp, &[BasicMetadataValueEnum::FloatValue(x)], name)
                            .ok()?;
                        let expx_minus_exp_negx = self
                            .builder
                            .build_float_sub(
                                expx.try_as_basic_value().left().unwrap().into_float_value(),
                                exp_negx
                                    .try_as_basic_value()
                                    .left()
                                    .unwrap()
                                    .into_float_value(),
                                name,
                            )
                            .ok()?;
                        let expx_plus_exp_negx = self
                            .builder
                            .build_float_add(
                                expx.try_as_basic_value().left().unwrap().into_float_value(),
                                exp_negx
                                    .try_as_basic_value()
                                    .left()
                                    .unwrap()
                                    .into_float_value(),
                                name,
                            )
                            .ok()?;
                        let result = match name {
                            "tanh" => self
                                .builder
                                .build_float_div(expx_minus_exp_negx, expx_plus_exp_negx, name)
                                .ok()?,
                            "sinh" => self
                                .builder
                                .build_float_div(
                                    expx_minus_exp_negx,
                                    self.real_type.const_float(2.0),
                                    name,
                                )
                                .ok()?,
                            "cosh" => self
                                .builder
                                .build_float_div(
                                    expx_plus_exp_negx,
                                    self.real_type.const_float(2.0),
                                    name,
                                )
                                .ok()?,
                            _ => panic!("unknown function"),
                        };
                        self.builder.build_return(Some(&result)).ok();
                        self.builder.position_at_end(current_block);
                        Some(fn_val)
                    }
                    _ => None,
                }?;
                // Cache the newly synthesised function for later lookups.
                self.functions.insert(name.to_owned(), function);
                Some(function)
            }
        }
    }
    /// The function currently being compiled; panics if no compilation is in
    /// progress (`fn_value_opt` unset).
    #[inline]
    fn fn_value(&self) -> FunctionValue<'ctx> {
        self.fn_value_opt.unwrap()
    }
1542
    /// The destination pointer of the tensor currently being compiled;
    /// panics if `tensor_ptr_opt` has not been set by `jit_compile_tensor`.
    #[inline]
    fn tensor_ptr(&self) -> PointerValue<'ctx> {
        self.tensor_ptr_opt.unwrap()
    }
1547
1548 fn create_entry_block_builder(&self) -> Builder<'ctx> {
1550 let builder = self.context.create_builder();
1551 let entry = self.fn_value().get_first_basic_block().unwrap();
1552 match entry.get_first_instruction() {
1553 Some(first_instr) => builder.position_before(&first_instr),
1554 None => builder.position_at_end(entry),
1555 }
1556 builder
1557 }
1558
    /// Compile a rank-0 (scalar) tensor: evaluate its single expression and
    /// store it through `res_ptr_opt` (or a fresh entry-block alloca).
    ///
    /// In threaded mode only thread 0 evaluates and stores the expression;
    /// all other threads branch straight to the exit block.
    fn jit_compile_scalar(
        &mut self,
        a: &Tensor,
        res_ptr_opt: Option<PointerValue<'ctx>>,
    ) -> Result<PointerValue<'ctx>> {
        let res_type = self.real_type;
        let res_ptr = match res_ptr_opt {
            Some(ptr) => ptr,
            None => self
                .create_entry_block_builder()
                .build_alloca(res_type, a.name())?,
        };
        let name = a.name();
        let elmt = a.elmts().first().unwrap();

        // Remember the entry point: in threaded mode we come back here to
        // insert the thread-0 guard after the body has been emitted.
        let curr_block = self.builder.get_insert_block().unwrap();
        let mut next_block_opt = None;
        if self.threaded {
            let next_block = self.context.append_basic_block(self.fn_value(), "next");
            self.builder.position_at_end(next_block);
            next_block_opt = Some(next_block);
        }

        let float_value = self.jit_compile_expr(name, elmt.expr(), &[], elmt, None)?;
        self.builder.build_store(res_ptr, float_value)?;

        if self.threaded {
            let exit_block = self.context.append_basic_block(self.fn_value(), "exit");
            self.builder.build_unconditional_branch(exit_block)?;
            // Back-fill the guard in the original block: only thread 0 takes
            // the "next" (store) path.
            self.builder.position_at_end(curr_block);

            let thread_id = self.get_param("thread_id");
            let thread_id = self
                .builder
                .build_load(self.int_type, *thread_id, "thread_id")
                .unwrap()
                .into_int_value();
            let is_first_thread = self.builder.build_int_compare(
                IntPredicate::EQ,
                thread_id,
                self.int_type.const_zero(),
                "is_first_thread",
            )?;
            self.builder.build_conditional_branch(
                is_first_thread,
                next_block_opt.unwrap(),
                exit_block,
            )?;

            self.builder.position_at_end(exit_block);
        }

        Ok(res_ptr)
    }
1615
1616 fn jit_compile_tensor(
1617 &mut self,
1618 a: &Tensor,
1619 res_ptr_opt: Option<PointerValue<'ctx>>,
1620 ) -> Result<PointerValue<'ctx>> {
1621 if a.rank() == 0 {
1623 return self.jit_compile_scalar(a, res_ptr_opt);
1624 }
1625
1626 let res_type = self.real_type;
1627 let res_ptr = match res_ptr_opt {
1628 Some(ptr) => ptr,
1629 None => self
1630 .create_entry_block_builder()
1631 .build_alloca(res_type, a.name())?,
1632 };
1633
1634 self.tensor_ptr_opt = Some(res_ptr);
1636
1637 for (i, blk) in a.elmts().iter().enumerate() {
1638 let default = format!("{}-{}", a.name(), i);
1639 let name = blk.name().unwrap_or(default.as_str());
1640 self.jit_compile_block(name, a, blk)?;
1641 }
1642 Ok(res_ptr)
1643 }
1644
1645 fn jit_compile_block(&mut self, name: &str, tensor: &Tensor, elmt: &TensorBlock) -> Result<()> {
1646 let translation = Translation::new(
1647 elmt.expr_layout(),
1648 elmt.layout(),
1649 elmt.start(),
1650 tensor.layout_ptr(),
1651 );
1652
1653 if elmt.expr_layout().is_dense() {
1654 self.jit_compile_dense_block(name, elmt, &translation)
1655 } else if elmt.expr_layout().is_diagonal() {
1656 self.jit_compile_diagonal_block(name, elmt, &translation)
1657 } else if elmt.expr_layout().is_sparse() {
1658 match translation.source {
1659 TranslationFrom::SparseContraction { .. } => {
1660 self.jit_compile_sparse_contraction_block(name, elmt, &translation)
1661 }
1662 _ => self.jit_compile_sparse_block(name, elmt, &translation),
1663 }
1664 } else {
1665 return Err(anyhow!(
1666 "unsupported block layout: {:?}",
1667 elmt.expr_layout()
1668 ));
1669 }
1670 }
1671
    /// Compile a dense tensor block as a nest of counted loops, one per
    /// dimension of the expression layout, using phi nodes for the loop
    /// indices. Handles an optional trailing-dimension contraction (sum over
    /// the innermost `contract_by` axes) and threaded partitioning of the
    /// outermost dimension.
    fn jit_compile_dense_block(
        &mut self,
        name: &str,
        elmt: &TensorBlock,
        translation: &Translation,
    ) -> Result<()> {
        let int_type = self.int_type;

        let mut preblock = self.builder.get_insert_block().unwrap();
        let expr_rank = elmt.expr_layout().rank();
        let expr_shape = elmt
            .expr_layout()
            .shape()
            .mapv(|n| int_type.const_int(n.try_into().unwrap(), false));
        let one = int_type.const_int(1, false);

        // Row-major strides of the expression layout.
        let mut expr_strides = vec![1; expr_rank];
        if expr_rank > 0 {
            for i in (0..expr_rank - 1).rev() {
                expr_strides[i] = expr_strides[i + 1] * elmt.expr_layout().shape()[i + 1];
            }
        }
        let expr_strides = expr_strides
            .iter()
            .map(|&s| int_type.const_int(s.try_into().unwrap(), false))
            .collect::<Vec<IntValue>>();

        // One phi (index) and one basic block per loop level.
        let mut indices = Vec::new();
        let mut blocks = Vec::new();

        // For dense contractions, allocate an accumulator and compute the
        // strides of the contracted (outer) part of the layout.
        let (contract_sum, contract_by, contract_strides) =
            if let TranslationFrom::DenseContraction {
                contract_by,
                contract_len: _,
            } = translation.source
            {
                let contract_rank = expr_rank - contract_by;
                let mut contract_strides = vec![1; contract_rank];
                for i in (0..contract_rank - 1).rev() {
                    contract_strides[i] =
                        contract_strides[i + 1] * elmt.expr_layout().shape()[i + 1];
                }
                let contract_strides = contract_strides
                    .iter()
                    .map(|&s| int_type.const_int(s.try_into().unwrap(), false))
                    .collect::<Vec<IntValue>>();
                (
                    Some(self.builder.build_alloca(self.real_type, "contract_sum")?),
                    contract_by,
                    Some(contract_strides),
                )
            } else {
                (None, 0, None)
            };

        // In threaded mode the outermost dimension is split across threads.
        let (thread_start, thread_end, test_done, next) = if self.threaded {
            let (start, end, test_done, next) =
                self.jit_threading_limits(*expr_shape.get(0).unwrap_or(&one))?;
            preblock = next;
            (Some(start), Some(end), Some(test_done), Some(next))
        } else {
            (None, None, None, None)
        };

        // Open the loop nest: one header block and index phi per dimension.
        for i in 0..expr_rank {
            let block = self.context.append_basic_block(self.fn_value(), name);
            self.builder.build_unconditional_branch(block)?;
            self.builder.position_at_end(block);

            let start_index = if i == 0 && self.threaded {
                thread_start.unwrap()
            } else {
                self.int_type.const_zero()
            };

            let curr_index = self.builder.build_phi(int_type, format!["i{i}"].as_str())?;
            curr_index.add_incoming(&[(&start_index, preblock)]);

            // Reset the accumulator at the start of each contracted slice.
            if i == expr_rank - contract_by - 1 && contract_sum.is_some() {
                self.builder
                    .build_store(contract_sum.unwrap(), self.real_type.const_zero())?;
            }

            indices.push(curr_index);
            blocks.push(block);
            preblock = block;
        }

        let indices_int: Vec<IntValue> = indices
            .iter()
            .map(|i| i.as_basic_value().into_int_value())
            .collect();

        // Loop body: evaluate the expression at the current indices.
        let float_value =
            self.jit_compile_expr(name, elmt.expr(), indices_int.as_slice(), elmt, None)?;

        if contract_sum.is_some() {
            // Accumulate into the contraction sum.
            let contract_sum_value = self
                .build_load(self.real_type, contract_sum.unwrap(), "contract_sum")?
                .into_float_value();
            let new_contract_sum_value = self.builder.build_float_add(
                contract_sum_value,
                float_value,
                "new_contract_sum",
            )?;
            self.builder
                .build_store(contract_sum.unwrap(), new_contract_sum_value)?;
        } else {
            // Flatten the indices into a linear offset and store directly.
            let expr_index = indices_int.iter().zip(expr_strides.iter()).fold(
                self.int_type.const_zero(),
                |acc, (i, s)| {
                    let tmp = self.builder.build_int_mul(*i, *s, "expr_index").unwrap();
                    self.builder.build_int_add(acc, tmp, "acc").unwrap()
                },
            );
            preblock = self.jit_compile_broadcast_and_store(
                name,
                elmt,
                float_value,
                expr_index,
                translation,
                preblock,
            )?;
        }

        // Close the loop nest from the innermost dimension outwards.
        for i in (0..expr_rank).rev() {
            let next_index = self.builder.build_int_add(indices_int[i], one, name)?;
            indices[i].add_incoming(&[(&next_index, preblock)]);

            // When the contracted slice finishes, store the accumulated sum.
            if i == expr_rank - contract_by - 1 && contract_sum.is_some() {
                let contract_sum_value = self
                    .build_load(self.real_type, contract_sum.unwrap(), "contract_sum")?
                    .into_float_value();
                let contract_strides = contract_strides.as_ref().unwrap();
                let elmt_index = indices_int
                    .iter()
                    .take(contract_strides.len())
                    .zip(contract_strides.iter())
                    .fold(self.int_type.const_zero(), |acc, (i, s)| {
                        let tmp = self.builder.build_int_mul(*i, *s, "elmt_index").unwrap();
                        self.builder.build_int_add(acc, tmp, "acc").unwrap()
                    });
                self.jit_compile_store(name, elmt, elmt_index, contract_sum_value, translation)?;
            }

            // The outermost loop's bound is the thread's slice end.
            let end_index = if i == 0 && self.threaded {
                thread_end.unwrap()
            } else {
                expr_shape[i]
            };

            let loop_while =
                self.builder
                    .build_int_compare(IntPredicate::ULT, next_index, end_index, name)?;
            let block = self.context.append_basic_block(self.fn_value(), name);
            self.builder
                .build_conditional_branch(loop_while, blocks[i], block)?;
            self.builder.position_at_end(block);
            preblock = block;
        }

        if self.threaded {
            self.jit_end_threading(
                thread_start.unwrap(),
                thread_end.unwrap(),
                test_done.unwrap(),
                next.unwrap(),
            )?;
        }
        Ok(())
    }
1850
    /// Compile a sparse contraction block: an outer loop over the nnz of the
    /// result layout, and for each result entry an inner loop summing the
    /// expression over the contracted entries. The `[start, end)` range of
    /// each inner loop is read from the "indices" array at the translation
    /// offset (stored as pairs, hence the `2 *` below).
    ///
    /// Panics unless `translation.source` is a `SparseContraction`.
    fn jit_compile_sparse_contraction_block(
        &mut self,
        name: &str,
        elmt: &TensorBlock,
        translation: &Translation,
    ) -> Result<()> {
        match translation.source {
            TranslationFrom::SparseContraction { .. } => {}
            _ => {
                panic!("expected sparse contraction")
            }
        }
        let int_type = self.int_type;

        // Offsets into the "indices" array: layout_index locates the
        // per-entry index tuples, translation_index the (start, end) pairs.
        let layout_index = self.layout.get_layout_index(elmt.expr_layout()).unwrap();
        let translation_index = self
            .layout
            .get_translation_index(elmt.expr_layout(), elmt.layout())
            .unwrap();
        let translation_index = translation_index + translation.get_from_index_in_data_layout();

        let final_contract_index =
            int_type.const_int(elmt.layout().nnz().try_into().unwrap(), false);
        let (thread_start, thread_end, test_done, next) = if self.threaded {
            let (start, end, test_done, next) = self.jit_threading_limits(final_contract_index)?;
            (Some(start), Some(end), Some(test_done), Some(next))
        } else {
            (None, None, None, None)
        };

        let preblock = self.builder.get_insert_block().unwrap();
        let contract_sum_ptr = self.builder.build_alloca(self.real_type, "contract_sum")?;

        // Outer loop: one iteration per nonzero of the result layout.
        let block = self.context.append_basic_block(self.fn_value(), name);
        self.builder.build_unconditional_branch(block)?;
        self.builder.position_at_end(block);

        let contract_index = self.builder.build_phi(int_type, "i")?;
        let contract_start = if self.threaded {
            thread_start.unwrap()
        } else {
            int_type.const_zero()
        };
        contract_index.add_incoming(&[(&contract_start, preblock)]);

        // Fetch this entry's (start, end) pair from the indices array:
        // start at translation_index + 2 * i, end immediately after.
        let start_index = self.builder.build_int_add(
            int_type.const_int(translation_index.try_into().unwrap(), false),
            self.builder.build_int_mul(
                int_type.const_int(2, false),
                contract_index.as_basic_value().into_int_value(),
                name,
            )?,
            name,
        )?;
        let end_index =
            self.builder
                .build_int_add(start_index, int_type.const_int(1, false), name)?;
        let start_ptr = self.build_gep(
            self.int_type,
            *self.get_param("indices"),
            &[start_index],
            "start_index_ptr",
        )?;
        let start_contract = self
            .build_load(self.int_type, start_ptr, "start")?
            .into_int_value();
        let end_ptr = self.build_gep(
            self.int_type,
            *self.get_param("indices"),
            &[end_index],
            "end_index_ptr",
        )?;
        let end_contract = self
            .build_load(self.int_type, end_ptr, "end")?
            .into_int_value();

        // Reset the accumulator for this result entry.
        self.builder
            .build_store(contract_sum_ptr, self.real_type.const_float(0.0))?;

        // Inner loop over the contracted expression entries.
        let contract_block = self
            .context
            .append_basic_block(self.fn_value(), format!("{name}_contract").as_str());
        self.builder.build_unconditional_branch(contract_block)?;
        self.builder.position_at_end(contract_block);

        let expr_index_phi = self.builder.build_phi(int_type, "j")?;
        expr_index_phi.add_incoming(&[(&start_contract, block)]);

        // Load the rank-many index tuple for this expression entry.
        let expr_index = expr_index_phi.as_basic_value().into_int_value();
        let elmt_index_mult_rank = self.builder.build_int_mul(
            expr_index,
            int_type.const_int(elmt.expr_layout().rank().try_into().unwrap(), false),
            name,
        )?;
        let indices_int = (0..elmt.expr_layout().rank())
            .map(|i| {
                let layout_index_plus_offset =
                    int_type.const_int((layout_index + i).try_into().unwrap(), false);
                let curr_index = self.builder.build_int_add(
                    elmt_index_mult_rank,
                    layout_index_plus_offset,
                    name,
                )?;
                let ptr = Self::get_ptr_to_index(
                    &self.builder,
                    self.int_type,
                    self.get_param("indices"),
                    curr_index,
                    name,
                );
                let index = self.build_load(self.int_type, ptr, name)?.into_int_value();
                Ok(index)
            })
            .collect::<Result<Vec<_>, anyhow::Error>>()?;

        // Evaluate the expression and accumulate into the running sum.
        let float_value = self.jit_compile_expr(
            name,
            elmt.expr(),
            indices_int.as_slice(),
            elmt,
            Some(expr_index),
        )?;
        let contract_sum_value = self
            .build_load(self.real_type, contract_sum_ptr, "contract_sum")?
            .into_float_value();
        let new_contract_sum_value =
            self.builder
                .build_float_add(contract_sum_value, float_value, "new_contract_sum")?;
        self.builder
            .build_store(contract_sum_ptr, new_contract_sum_value)?;

        let next_elmt_index =
            self.builder
                .build_int_add(expr_index, int_type.const_int(1, false), name)?;
        expr_index_phi.add_incoming(&[(&next_elmt_index, contract_block)]);

        // Inner loop back-edge: continue while j + 1 < end.
        let loop_while = self.builder.build_int_compare(
            IntPredicate::ULT,
            next_elmt_index,
            end_contract,
            name,
        )?;
        let post_contract_block = self.context.append_basic_block(self.fn_value(), name);
        self.builder
            .build_conditional_branch(loop_while, contract_block, post_contract_block)?;
        self.builder.position_at_end(post_contract_block);

        // Store the finished sum for this result entry.
        self.jit_compile_store(
            name,
            elmt,
            contract_index.as_basic_value().into_int_value(),
            new_contract_sum_value,
            translation,
        )?;

        let next_contract_index = self.builder.build_int_add(
            contract_index.as_basic_value().into_int_value(),
            int_type.const_int(1, false),
            name,
        )?;
        contract_index.add_incoming(&[(&next_contract_index, post_contract_block)]);

        // Outer loop back-edge: bound is the thread slice end when threaded.
        let loop_while = self.builder.build_int_compare(
            IntPredicate::ULT,
            next_contract_index,
            thread_end.unwrap_or(final_contract_index),
            name,
        )?;
        let post_block = self.context.append_basic_block(self.fn_value(), name);
        self.builder
            .build_conditional_branch(loop_while, block, post_block)?;
        self.builder.position_at_end(post_block);

        if self.threaded {
            self.jit_end_threading(
                thread_start.unwrap(),
                thread_end.unwrap(),
                test_done.unwrap(),
                next.unwrap(),
            )?;
        }

        Ok(())
    }
2045
    /// Compile a (non-contracting) sparse block: a single loop over the nnz
    /// of the expression layout. For each nonzero, the rank-many index tuple
    /// is loaded from the "indices" array, the expression is evaluated, and
    /// the result is broadcast/stored via the translation.
    fn jit_compile_sparse_block(
        &mut self,
        name: &str,
        elmt: &TensorBlock,
        translation: &Translation,
    ) -> Result<()> {
        let int_type = self.int_type;

        // Offset of this layout's index tuples within the "indices" array.
        let layout_index = self.layout.get_layout_index(elmt.expr_layout()).unwrap();
        let start_index = int_type.const_int(0, false);
        let end_index = int_type.const_int(elmt.expr_layout().nnz().try_into().unwrap(), false);

        // In threaded mode each thread handles a slice of the nonzeros.
        let (thread_start, thread_end, test_done, next) = if self.threaded {
            let (start, end, test_done, next) = self.jit_threading_limits(end_index)?;
            (Some(start), Some(end), Some(test_done), Some(next))
        } else {
            (None, None, None, None)
        };

        let preblock = self.builder.get_insert_block().unwrap();
        let mut block = self.context.append_basic_block(self.fn_value(), name);
        self.builder.build_unconditional_branch(block)?;
        self.builder.position_at_end(block);

        let curr_index = self.builder.build_phi(int_type, "i")?;
        curr_index.add_incoming(&[(&thread_start.unwrap_or(start_index), preblock)]);

        // Load the rank-many index tuple for nonzero `i`.
        let elmt_index = curr_index.as_basic_value().into_int_value();
        let elmt_index_mult_rank = self.builder.build_int_mul(
            elmt_index,
            int_type.const_int(elmt.expr_layout().rank().try_into().unwrap(), false),
            name,
        )?;
        let indices_int = (0..elmt.expr_layout().rank())
            .map(|i| {
                let layout_index_plus_offset =
                    int_type.const_int((layout_index + i).try_into().unwrap(), false);
                let curr_index = self.builder.build_int_add(
                    elmt_index_mult_rank,
                    layout_index_plus_offset,
                    name,
                )?;
                let ptr = Self::get_ptr_to_index(
                    &self.builder,
                    self.int_type,
                    self.get_param("indices"),
                    curr_index,
                    name,
                );
                Ok(self.build_load(self.int_type, ptr, name)?.into_int_value())
            })
            .collect::<Result<Vec<_>, anyhow::Error>>()?;

        // Evaluate the expression at this nonzero and store the result.
        let float_value = self.jit_compile_expr(
            name,
            elmt.expr(),
            indices_int.as_slice(),
            elmt,
            Some(elmt_index),
        )?;

        block = self.jit_compile_broadcast_and_store(
            name,
            elmt,
            float_value,
            elmt_index,
            translation,
            block,
        )?;

        // Loop back-edge: continue while i + 1 < end (or the thread's end).
        let one = int_type.const_int(1, false);
        let next_index = self.builder.build_int_add(elmt_index, one, name)?;
        curr_index.add_incoming(&[(&next_index, block)]);

        let loop_while = self.builder.build_int_compare(
            IntPredicate::ULT,
            next_index,
            thread_end.unwrap_or(end_index),
            name,
        )?;
        let post_block = self.context.append_basic_block(self.fn_value(), name);
        self.builder
            .build_conditional_branch(loop_while, block, post_block)?;
        self.builder.position_at_end(post_block);

        if self.threaded {
            self.jit_end_threading(
                thread_start.unwrap(),
                thread_end.unwrap(),
                test_done.unwrap(),
                next.unwrap(),
            )?;
        }

        Ok(())
    }
2149
    /// JIT-compiles a loop evaluating `elmt`'s expression for each stored entry
    /// of a diagonal tensor block.
    ///
    /// Identical in shape to `jit_compile_sparse_block`, except that on the
    /// diagonal every dimension of the element's multi-dimensional index
    /// equals the flat element index, so no index-table lookups are needed.
    fn jit_compile_diagonal_block(
        &mut self,
        name: &str,
        elmt: &TensorBlock,
        translation: &Translation,
    ) -> Result<()> {
        let int_type = self.int_type;

        let start_index = int_type.const_int(0, false);
        let end_index = int_type.const_int(elmt.expr_layout().nnz().try_into().unwrap(), false);

        // When threading, restrict the loop to this thread's sub-range.
        let (thread_start, thread_end, test_done, next) = if self.threaded {
            let (start, end, test_done, next) = self.jit_threading_limits(end_index)?;
            (Some(start), Some(end), Some(test_done), Some(next))
        } else {
            (None, None, None, None)
        };

        // Loop header with an induction-variable phi.
        let preblock = self.builder.get_insert_block().unwrap();
        let mut block = self.context.append_basic_block(self.fn_value(), name);
        self.builder.build_unconditional_branch(block)?;
        self.builder.position_at_end(block);

        let curr_index = self.builder.build_phi(int_type, "i")?;
        curr_index.add_incoming(&[(&thread_start.unwrap_or(start_index), preblock)]);

        // On the diagonal, each rank's index is the element index itself.
        let elmt_index = curr_index.as_basic_value().into_int_value();
        let indices_int: Vec<IntValue> =
            (0..elmt.expr_layout().rank()).map(|_| elmt_index).collect();

        let float_value = self.jit_compile_expr(
            name,
            elmt.expr(),
            indices_int.as_slice(),
            elmt,
            Some(elmt_index),
        )?;

        // Store (possibly broadcast); `block` becomes the loop's latch block.
        block = self.jit_compile_broadcast_and_store(
            name,
            elmt,
            float_value,
            elmt_index,
            translation,
            block,
        )?;

        // Increment and test (do-while: the body always runs at least once).
        let one = int_type.const_int(1, false);
        let next_index = self.builder.build_int_add(elmt_index, one, name)?;
        curr_index.add_incoming(&[(&next_index, block)]);

        // NOTE(review): as in the sparse case, the back edge targets `block`,
        // which after a Broadcast translation is the post-broadcast block, not
        // the phi header — presumably Broadcast never occurs here; confirm.
        let loop_while = self.builder.build_int_compare(
            IntPredicate::ULT,
            next_index,
            thread_end.unwrap_or(end_index),
            name,
        )?;
        let post_block = self.context.append_basic_block(self.fn_value(), name);
        self.builder
            .build_conditional_branch(loop_while, block, post_block)?;
        self.builder.position_at_end(post_block);

        // Synchronize and tear down the per-thread bookkeeping.
        if self.threaded {
            self.jit_end_threading(
                thread_start.unwrap(),
                thread_end.unwrap(),
                test_done.unwrap(),
                next.unwrap(),
            )?;
        }

        Ok(())
    }
2230
    /// Stores `float_value` into the target tensor, replicating it when the
    /// translation requires broadcasting.
    ///
    /// For a `Broadcast` source this emits an inner do-while loop writing the
    /// value into `broadcast_len` consecutive store slots; element-wise and
    /// diagonal-contraction sources perform a single store. Returns the basic
    /// block the caller's loop back edge should originate from.
    fn jit_compile_broadcast_and_store(
        &mut self,
        name: &str,
        elmt: &TensorBlock,
        float_value: FloatValue<'ctx>,
        expr_index: IntValue<'ctx>,
        translation: &Translation,
        pre_block: BasicBlock<'ctx>,
    ) -> Result<BasicBlock<'ctx>> {
        let int_type = self.int_type;
        let one = int_type.const_int(1, false);
        let zero = int_type.const_int(0, false);
        match translation.source {
            TranslationFrom::Broadcast {
                broadcast_by: _,
                broadcast_len,
            } => {
                let bcast_start_index = zero;
                let bcast_end_index = int_type.const_int(broadcast_len.try_into().unwrap(), false);

                // Inner loop header: the phi merges the entry edge (from
                // `pre_block`) with the back edge added below.
                let bcast_block = self.context.append_basic_block(self.fn_value(), name);
                self.builder.build_unconditional_branch(bcast_block)?;
                self.builder.position_at_end(bcast_block);
                let bcast_index = self.builder.build_phi(int_type, "broadcast_index")?;
                bcast_index.add_incoming(&[(&bcast_start_index, pre_block)]);

                // Flattened destination: expr_index * broadcast_len + bcast_index.
                let store_index = self.builder.build_int_add(
                    self.builder
                        .build_int_mul(expr_index, bcast_end_index, "store_index")?,
                    bcast_index.as_basic_value().into_int_value(),
                    "bcast_store_index",
                )?;
                self.jit_compile_store(name, elmt, store_index, float_value, translation)?;

                // Advance the broadcast counter and wire up the back edge.
                let bcast_next_index = self.builder.build_int_add(
                    bcast_index.as_basic_value().into_int_value(),
                    one,
                    name,
                )?;
                bcast_index.add_incoming(&[(&bcast_next_index, bcast_block)]);

                // Do-while exit test: continue while next < broadcast_len.
                let bcast_cond = self.builder.build_int_compare(
                    IntPredicate::ULT,
                    bcast_next_index,
                    bcast_end_index,
                    "broadcast_cond",
                )?;
                let post_bcast_block = self.context.append_basic_block(self.fn_value(), name);
                self.builder
                    .build_conditional_branch(bcast_cond, bcast_block, post_bcast_block)?;
                self.builder.position_at_end(post_bcast_block);

                // Subsequent emission continues in the post-broadcast block.
                Ok(post_bcast_block)
            }
            TranslationFrom::ElementWise | TranslationFrom::DiagonalContraction { .. } => {
                // One-to-one mapping: a single store, control flow unchanged.
                self.jit_compile_store(name, elmt, expr_index, float_value, translation)?;
                Ok(pre_block)
            }
            _ => Err(anyhow!("Invalid translation")),
        }
    }
2297
    /// Stores a computed element value into this tensor's backing storage.
    ///
    /// `store_index` is the element's position within the *expression* layout;
    /// the translation maps it into the *target* layout either by a constant
    /// offset (contiguous target) or via a precomputed translation table
    /// loaded from the "indices" parameter (sparse target).
    fn jit_compile_store(
        &mut self,
        name: &str,
        elmt: &TensorBlock,
        store_index: IntValue<'ctx>,
        float_value: FloatValue<'ctx>,
        translation: &Translation,
    ) -> Result<()> {
        let int_type = self.int_type;
        let res_index = match &translation.target {
            TranslationTo::Contiguous { start, end: _ } => {
                // Target occupies a contiguous range: offset by its start.
                let start_const = int_type.const_int((*start).try_into().unwrap(), false);
                self.builder.build_int_add(start_const, store_index, name)?
            }
            TranslationTo::Sparse { indices: _ } => {
                // Sparse target: the destination index is looked up in the
                // translation table embedded in the "indices" parameter.
                let translate_index = self
                    .layout
                    .get_translation_index(elmt.expr_layout(), elmt.layout())
                    .unwrap();
                let translate_store_index =
                    translate_index + translation.get_to_index_in_data_layout();
                let translate_store_index =
                    int_type.const_int(translate_store_index.try_into().unwrap(), false);
                let elmt_index_strided = store_index;
                let curr_index =
                    self.builder
                        .build_int_add(elmt_index_strided, translate_store_index, name)?;
                let ptr = Self::get_ptr_to_index(
                    &self.builder,
                    self.int_type,
                    self.get_param("indices"),
                    curr_index,
                    name,
                );
                self.build_load(self.int_type, ptr, name)?.into_int_value()
            }
        };

        // Write the value into the current tensor's data at the target slot.
        let resi_ptr = Self::get_ptr_to_index(
            &self.builder,
            self.real_type,
            &self.tensor_ptr(),
            res_index,
            name,
        );
        self.builder.build_store(resi_ptr, float_value)?;
        Ok(())
    }
2347
    /// Recursively JIT-compiles an AST expression into a float value.
    ///
    /// `index` supplies one integer per rank of the current element's
    /// multi-dimensional index (used to address dense tensor operands);
    /// `expr_index` is the flat element index used for sparse/diagonal
    /// operand layouts.
    fn jit_compile_expr(
        &mut self,
        name: &str,
        expr: &Ast,
        index: &[IntValue<'ctx>],
        elmt: &TensorBlock,
        expr_index: Option<IntValue<'ctx>>,
    ) -> Result<FloatValue<'ctx>> {
        // Prefer the tensor block's own name for generated IR value names.
        let name = elmt.name().unwrap_or(name);
        match &expr.kind {
            // Binary float arithmetic: compile both sides, then combine.
            AstKind::Binop(binop) => {
                let lhs =
                    self.jit_compile_expr(name, binop.left.as_ref(), index, elmt, expr_index)?;
                let rhs =
                    self.jit_compile_expr(name, binop.right.as_ref(), index, elmt, expr_index)?;
                match binop.op {
                    '*' => Ok(self.builder.build_float_mul(lhs, rhs, name)?),
                    '/' => Ok(self.builder.build_float_div(lhs, rhs, name)?),
                    '-' => Ok(self.builder.build_float_sub(lhs, rhs, name)?),
                    '+' => Ok(self.builder.build_float_add(lhs, rhs, name)?),
                    unknown => Err(anyhow!("unknown binop op '{}'", unknown)),
                }
            }
            // Unary operators: only negation is supported.
            AstKind::Monop(monop) => {
                let child =
                    self.jit_compile_expr(name, monop.child.as_ref(), index, elmt, expr_index)?;
                match monop.op {
                    '-' => Ok(self.builder.build_float_neg(child, name)?),
                    unknown => Err(anyhow!("unknown monop op '{}'", unknown)),
                }
            }
            // Function calls: resolve the callee among registered functions,
            // compile each argument, and take the float result of the call.
            AstKind::Call(call) => match self.get_function(call.fn_name) {
                Some(function) => {
                    let mut args: Vec<BasicMetadataValueEnum> = Vec::new();
                    for arg in call.args.iter() {
                        let arg_val =
                            self.jit_compile_expr(name, arg.as_ref(), index, elmt, expr_index)?;
                        args.push(BasicMetadataValueEnum::FloatValue(arg_val));
                    }
                    let ret_value = self
                        .builder
                        .build_call(function, args.as_slice(), name)?
                        .try_as_basic_value()
                        .left()
                        .unwrap()
                        .into_float_value();
                    Ok(ret_value)
                }
                None => Err(anyhow!("unknown function call '{}'", call.fn_name)),
            },
            // A call argument is transparent: compile its inner expression.
            AstKind::CallArg(arg) => {
                self.jit_compile_expr(name, &arg.expression, index, elmt, expr_index)
            }
            // Numeric literal → float constant.
            AstKind::Number(value) => Ok(self.real_type.const_float(*value)),
            // A named tensor operand: compute its flat element offset from the
            // current index and load the value from the operand's buffer.
            AstKind::Name(iname) => {
                let ptr = self.get_param(iname.name);
                let layout = self.layout.get_layout(iname.name).unwrap();
                let iname_elmt_index = if layout.is_dense() {
                    // Map each of the operand's index characters to the
                    // position of the same character in this block's indices;
                    // unmatched characters fall past the end of `index`.
                    let mut no_transform = true;
                    let mut iname_index = Vec::new();
                    for (i, c) in iname.indices.iter().enumerate() {
                        let pi = elmt
                            .indices()
                            .iter()
                            .position(|x| x == c)
                            .unwrap_or(elmt.indices().len());
                        iname_index.push(index[pi]);
                        // NOTE(review): `no_transform` is computed but never
                        // read after this loop — dead-code candidate.
                        no_transform = no_transform && pi == i;
                    }
                    if !iname_index.is_empty() {
                        // Row-major flattening: accumulate index[i] * stride[i]
                        // from the innermost dimension outwards.
                        // NOTE(review): strides are materialized with
                        // `i32_type` while loop indices use `self.int_type` —
                        // confirm both are the same width.
                        let mut iname_elmt_index = *iname_index.last().unwrap();
                        let mut stride = 1u64;
                        for i in (0..iname_index.len() - 1).rev() {
                            let iname_i = iname_index[i];
                            let shapei: u64 = layout.shape()[i + 1].try_into().unwrap();
                            stride *= shapei;
                            let stride_intval = self.context.i32_type().const_int(stride, false);
                            let stride_mul_i =
                                self.builder.build_int_mul(stride_intval, iname_i, name)?;
                            iname_elmt_index =
                                self.builder
                                    .build_int_add(iname_elmt_index, stride_mul_i, name)?;
                        }
                        Some(iname_elmt_index)
                    } else {
                        // Scalar operand: load from offset zero.
                        let zero = self.context.i32_type().const_int(0, false);
                        Some(zero)
                    }
                } else if layout.is_sparse() || layout.is_diagonal() {
                    // Sparse/diagonal operands reuse the block's flat element
                    // index; `None` here means load through the raw pointer.
                    expr_index
                } else {
                    panic!("unexpected layout");
                };
                let value_ptr = match iname_elmt_index {
                    Some(index) => {
                        Self::get_ptr_to_index(&self.builder, self.real_type, ptr, index, name)
                    }
                    None => *ptr,
                };
                Ok(self
                    .build_load(self.real_type, value_ptr, name)?
                    .into_float_value())
            }
            // A named gradient (e.g. a dot-state) is registered as a parameter
            // under its own name; load it directly.
            AstKind::NamedGradient(name) => {
                let name_str = name.to_string();
                let ptr = self.get_param(name_str.as_str());
                Ok(self
                    .build_load(self.real_type, *ptr, name_str.as_str())?
                    .into_float_value())
            }
            AstKind::Index(_) => todo!(),
            AstKind::Slice(_) => todo!(),
            AstKind::Integer(_) => todo!(),
            _ => panic!("unexprected astkind"),
        }
    }
2471
2472 fn clear(&mut self) {
2473 self.variables.clear();
2474 self.fn_value_opt = None;
2476 self.tensor_ptr_opt = None;
2477 }
2478
2479 fn function_arg_alloca(&mut self, name: &str, arg: BasicValueEnum<'ctx>) -> PointerValue<'ctx> {
2480 match arg {
2481 BasicValueEnum::PointerValue(v) => v,
2482 BasicValueEnum::FloatValue(v) => {
2483 let alloca = self
2484 .create_entry_block_builder()
2485 .build_alloca(arg.get_type(), name)
2486 .unwrap();
2487 self.builder.build_store(alloca, v).unwrap();
2488 alloca
2489 }
2490 BasicValueEnum::IntValue(v) => {
2491 let alloca = self
2492 .create_entry_block_builder()
2493 .build_alloca(arg.get_type(), name)
2494 .unwrap();
2495 self.builder.build_store(alloca, v).unwrap();
2496 alloca
2497 }
2498 _ => unreachable!(),
2499 }
2500 }
2501
2502 pub fn compile_set_u0<'m>(&mut self, model: &'m DiscreteModel) -> Result<FunctionValue<'ctx>> {
2503 self.clear();
2504 let void_type = self.context.void_type();
2505 let fn_type = void_type.fn_type(
2506 &[
2507 self.real_ptr_type.into(),
2508 self.real_ptr_type.into(),
2509 self.int_type.into(),
2510 self.int_type.into(),
2511 ],
2512 false,
2513 );
2514 let fn_arg_names = &["u0", "data", "thread_id", "thread_dim"];
2515 let function = self.module.add_function("set_u0", fn_type, None);
2516
2517 let alias_id = Attribute::get_named_enum_kind_id("noalias");
2519 let noalign = self.context.create_enum_attribute(alias_id, 0);
2520 for i in &[0, 1] {
2521 function.add_attribute(AttributeLoc::Param(*i), noalign);
2522 }
2523
2524 let basic_block = self.context.append_basic_block(function, "entry");
2525 self.fn_value_opt = Some(function);
2526 self.builder.position_at_end(basic_block);
2527
2528 for (i, arg) in function.get_param_iter().enumerate() {
2529 let name = fn_arg_names[i];
2530 let alloca = self.function_arg_alloca(name, arg);
2531 self.insert_param(name, alloca);
2532 }
2533
2534 self.insert_data(model);
2535 self.insert_indices();
2536
2537 let mut nbarriers = 0;
2538 let total_barriers = (model.input_dep_defns().len() + 1) as u64;
2539 #[allow(clippy::explicit_counter_loop)]
2540 for a in model.input_dep_defns() {
2541 self.jit_compile_tensor(a, Some(*self.get_var(a)))?;
2542 self.jit_compile_call_barrier(nbarriers, total_barriers);
2543 nbarriers += 1;
2544 }
2545
2546 self.jit_compile_tensor(model.state(), Some(*self.get_param("u0")))?;
2547 self.jit_compile_call_barrier(nbarriers, total_barriers);
2548
2549 self.builder.build_return(None)?;
2550
2551 if function.verify(true) {
2552 Ok(function)
2553 } else {
2554 function.print_to_stderr();
2555 unsafe {
2556 function.delete();
2557 }
2558 Err(anyhow!("Invalid generated function."))
2559 }
2560 }
2561
2562 pub fn compile_calc_out<'m>(
2563 &mut self,
2564 model: &'m DiscreteModel,
2565 include_constants: bool,
2566 ) -> Result<FunctionValue<'ctx>> {
2567 self.clear();
2568 let void_type = self.context.void_type();
2569 let fn_type = void_type.fn_type(
2570 &[
2571 self.real_type.into(),
2572 self.real_ptr_type.into(),
2573 self.real_ptr_type.into(),
2574 self.real_ptr_type.into(),
2575 self.int_type.into(),
2576 self.int_type.into(),
2577 ],
2578 false,
2579 );
2580 let fn_arg_names = &["t", "u", "data", "out", "thread_id", "thread_dim"];
2581 let function_name = if include_constants {
2582 "calc_out_full"
2583 } else {
2584 "calc_out"
2585 };
2586 let function = self.module.add_function(function_name, fn_type, None);
2587
2588 let alias_id = Attribute::get_named_enum_kind_id("noalias");
2590 let noalign = self.context.create_enum_attribute(alias_id, 0);
2591 for i in &[1, 2] {
2592 function.add_attribute(AttributeLoc::Param(*i), noalign);
2593 }
2594
2595 let basic_block = self.context.append_basic_block(function, "entry");
2596 self.fn_value_opt = Some(function);
2597 self.builder.position_at_end(basic_block);
2598
2599 for (i, arg) in function.get_param_iter().enumerate() {
2600 let name = fn_arg_names[i];
2601 let alloca = self.function_arg_alloca(name, arg);
2602 self.insert_param(name, alloca);
2603 }
2604
2605 self.insert_state(model.state());
2606 self.insert_data(model);
2607 self.insert_indices();
2608
2609 if let Some(out) = model.out() {
2615 let mut nbarriers = 0;
2616 let mut total_barriers =
2617 (model.time_dep_defns().len() + model.state_dep_defns().len() + 1) as u64;
2618 if include_constants {
2619 total_barriers += model.input_dep_defns().len() as u64;
2620 for tensor in model.input_dep_defns() {
2622 self.jit_compile_tensor(tensor, Some(*self.get_var(tensor)))?;
2623 self.jit_compile_call_barrier(nbarriers, total_barriers);
2624 nbarriers += 1;
2625 }
2626 }
2627
2628 for tensor in model.time_dep_defns() {
2630 self.jit_compile_tensor(tensor, Some(*self.get_var(tensor)))?;
2631 self.jit_compile_call_barrier(nbarriers, total_barriers);
2632 nbarriers += 1;
2633 }
2634
2635 #[allow(clippy::explicit_counter_loop)]
2637 for a in model.state_dep_defns() {
2638 self.jit_compile_tensor(a, Some(*self.get_var(a)))?;
2639 self.jit_compile_call_barrier(nbarriers, total_barriers);
2640 nbarriers += 1;
2641 }
2642
2643 self.jit_compile_tensor(out, Some(*self.get_var(model.out().unwrap())))?;
2644 self.jit_compile_call_barrier(nbarriers, total_barriers);
2645 }
2646 self.builder.build_return(None)?;
2647
2648 if function.verify(true) {
2649 Ok(function)
2650 } else {
2651 function.print_to_stderr();
2652 unsafe {
2653 function.delete();
2654 }
2655 Err(anyhow!("Invalid generated function."))
2656 }
2657 }
2658
2659 pub fn compile_calc_stop<'m>(
2660 &mut self,
2661 model: &'m DiscreteModel,
2662 ) -> Result<FunctionValue<'ctx>> {
2663 self.clear();
2664 let void_type = self.context.void_type();
2665 let fn_type = void_type.fn_type(
2666 &[
2667 self.real_type.into(),
2668 self.real_ptr_type.into(),
2669 self.real_ptr_type.into(),
2670 self.real_ptr_type.into(),
2671 self.int_type.into(),
2672 self.int_type.into(),
2673 ],
2674 false,
2675 );
2676 let fn_arg_names = &["t", "u", "data", "root", "thread_id", "thread_dim"];
2677 let function = self.module.add_function("calc_stop", fn_type, None);
2678
2679 let alias_id = Attribute::get_named_enum_kind_id("noalias");
2681 let noalign = self.context.create_enum_attribute(alias_id, 0);
2682 for i in &[1, 2, 3] {
2683 function.add_attribute(AttributeLoc::Param(*i), noalign);
2684 }
2685
2686 let basic_block = self.context.append_basic_block(function, "entry");
2687 self.fn_value_opt = Some(function);
2688 self.builder.position_at_end(basic_block);
2689
2690 for (i, arg) in function.get_param_iter().enumerate() {
2691 let name = fn_arg_names[i];
2692 let alloca = self.function_arg_alloca(name, arg);
2693 self.insert_param(name, alloca);
2694 }
2695
2696 self.insert_state(model.state());
2697 self.insert_data(model);
2698 self.insert_indices();
2699
2700 if let Some(stop) = model.stop() {
2701 let mut nbarriers = 0;
2703 let total_barriers =
2704 (model.time_dep_defns().len() + model.state_dep_defns().len() + 1) as u64;
2705 for tensor in model.time_dep_defns() {
2706 self.jit_compile_tensor(tensor, Some(*self.get_var(tensor)))?;
2707 self.jit_compile_call_barrier(nbarriers, total_barriers);
2708 nbarriers += 1;
2709 }
2710
2711 for a in model.state_dep_defns() {
2713 self.jit_compile_tensor(a, Some(*self.get_var(a)))?;
2714 self.jit_compile_call_barrier(nbarriers, total_barriers);
2715 nbarriers += 1;
2716 }
2717
2718 let res_ptr = self.get_param("root");
2719 self.jit_compile_tensor(stop, Some(*res_ptr))?;
2720 self.jit_compile_call_barrier(nbarriers, total_barriers);
2721 }
2722 self.builder.build_return(None)?;
2723
2724 if function.verify(true) {
2725 Ok(function)
2726 } else {
2727 function.print_to_stderr();
2728 unsafe {
2729 function.delete();
2730 }
2731 Err(anyhow!("Invalid generated function."))
2732 }
2733 }
2734
2735 pub fn compile_rhs<'m>(
2736 &mut self,
2737 model: &'m DiscreteModel,
2738 include_constants: bool,
2739 ) -> Result<FunctionValue<'ctx>> {
2740 self.clear();
2741 let void_type = self.context.void_type();
2742 let fn_type = void_type.fn_type(
2743 &[
2744 self.real_type.into(),
2745 self.real_ptr_type.into(),
2746 self.real_ptr_type.into(),
2747 self.real_ptr_type.into(),
2748 self.int_type.into(),
2749 self.int_type.into(),
2750 ],
2751 false,
2752 );
2753 let fn_arg_names = &["t", "u", "data", "rr", "thread_id", "thread_dim"];
2754 let function_name = if include_constants { "rhs_full" } else { "rhs" };
2755 let function = self.module.add_function(function_name, fn_type, None);
2756
2757 let alias_id = Attribute::get_named_enum_kind_id("noalias");
2759 let noalign = self.context.create_enum_attribute(alias_id, 0);
2760 for i in &[1, 2, 3] {
2761 function.add_attribute(AttributeLoc::Param(*i), noalign);
2762 }
2763
2764 let basic_block = self.context.append_basic_block(function, "entry");
2765 self.fn_value_opt = Some(function);
2766 self.builder.position_at_end(basic_block);
2767
2768 for (i, arg) in function.get_param_iter().enumerate() {
2769 let name = fn_arg_names[i];
2770 let alloca = self.function_arg_alloca(name, arg);
2771 self.insert_param(name, alloca);
2772 }
2773
2774 self.insert_state(model.state());
2775 self.insert_data(model);
2776 self.insert_indices();
2777
2778 let mut nbarriers = 0;
2779 let mut total_barriers =
2780 (model.time_dep_defns().len() + model.state_dep_defns().len() + 1) as u64;
2781 if include_constants {
2782 total_barriers += model.input_dep_defns().len() as u64;
2783 for tensor in model.input_dep_defns() {
2785 self.jit_compile_tensor(tensor, Some(*self.get_var(tensor)))?;
2786 self.jit_compile_call_barrier(nbarriers, total_barriers);
2787 nbarriers += 1;
2788 }
2789 }
2790
2791 for tensor in model.time_dep_defns() {
2793 self.jit_compile_tensor(tensor, Some(*self.get_var(tensor)))?;
2794 self.jit_compile_call_barrier(nbarriers, total_barriers);
2795 nbarriers += 1;
2796 }
2797
2798 for a in model.state_dep_defns() {
2800 self.jit_compile_tensor(a, Some(*self.get_var(a)))?;
2801 self.jit_compile_call_barrier(nbarriers, total_barriers);
2802 nbarriers += 1;
2803 }
2804
2805 let res_ptr = self.get_param("rr");
2807 self.jit_compile_tensor(model.rhs(), Some(*res_ptr))?;
2808 self.jit_compile_call_barrier(nbarriers, total_barriers);
2809
2810 self.builder.build_return(None)?;
2811
2812 if function.verify(true) {
2813 Ok(function)
2814 } else {
2815 function.print_to_stderr();
2816 unsafe {
2817 function.delete();
2818 }
2819 Err(anyhow!("Invalid generated function."))
2820 }
2821 }
2822
2823 pub fn compile_mass<'m>(&mut self, model: &'m DiscreteModel) -> Result<FunctionValue<'ctx>> {
2824 self.clear();
2825 let void_type = self.context.void_type();
2826 let fn_type = void_type.fn_type(
2827 &[
2828 self.real_type.into(),
2829 self.real_ptr_type.into(),
2830 self.real_ptr_type.into(),
2831 self.real_ptr_type.into(),
2832 self.int_type.into(),
2833 self.int_type.into(),
2834 ],
2835 false,
2836 );
2837 let fn_arg_names = &["t", "dudt", "data", "rr", "thread_id", "thread_dim"];
2838 let function = self.module.add_function("mass", fn_type, None);
2839
2840 let alias_id = Attribute::get_named_enum_kind_id("noalias");
2842 let noalign = self.context.create_enum_attribute(alias_id, 0);
2843 for i in &[1, 2, 3] {
2844 function.add_attribute(AttributeLoc::Param(*i), noalign);
2845 }
2846
2847 let basic_block = self.context.append_basic_block(function, "entry");
2848 self.fn_value_opt = Some(function);
2849 self.builder.position_at_end(basic_block);
2850
2851 for (i, arg) in function.get_param_iter().enumerate() {
2852 let name = fn_arg_names[i];
2853 let alloca = self.function_arg_alloca(name, arg);
2854 self.insert_param(name, alloca);
2855 }
2856
2857 if model.state_dot().is_some() && model.lhs().is_some() {
2859 let state_dot = model.state_dot().unwrap();
2860 let lhs = model.lhs().unwrap();
2861
2862 self.insert_dot_state(state_dot);
2863 self.insert_data(model);
2864 self.insert_indices();
2865
2866 let mut nbarriers = 0;
2868 let total_barriers =
2869 (model.time_dep_defns().len() + model.dstate_dep_defns().len() + 1) as u64;
2870 for tensor in model.time_dep_defns() {
2871 self.jit_compile_tensor(tensor, Some(*self.get_var(tensor)))?;
2872 self.jit_compile_call_barrier(nbarriers, total_barriers);
2873 nbarriers += 1;
2874 }
2875
2876 for a in model.dstate_dep_defns() {
2877 self.jit_compile_tensor(a, Some(*self.get_var(a)))?;
2878 self.jit_compile_call_barrier(nbarriers, total_barriers);
2879 nbarriers += 1;
2880 }
2881
2882 let res_ptr = self.get_param("rr");
2884 self.jit_compile_tensor(lhs, Some(*res_ptr))?;
2885 self.jit_compile_call_barrier(nbarriers, total_barriers);
2886 }
2887
2888 self.builder.build_return(None)?;
2889
2890 if function.verify(true) {
2891 Ok(function)
2892 } else {
2893 function.print_to_stderr();
2894 unsafe {
2895 function.delete();
2896 }
2897 Err(anyhow!("Invalid generated function."))
2898 }
2899 }
2900
    /// Differentiates `original_function` with Enzyme and emits a wrapper
    /// function named `fn_name` that calls the generated derivative.
    ///
    /// The wrapper's parameter list mirrors the original's, except that every
    /// argument whose `args_type[i]` is `Dup` or `DupNoNeed` is followed by an
    /// extra shadow (derivative) argument of the same type. `mode` chooses
    /// forward-mode or combined reverse-mode AD.
    pub fn compile_gradient(
        &mut self,
        original_function: FunctionValue<'ctx>,
        args_type: &[CompileGradientArgType],
        mode: CompileMode,
        fn_name: &str,
    ) -> Result<FunctionValue<'ctx>> {
        self.clear();

        // Parameter types of the wrapper we are about to create.
        let mut fn_type: Vec<BasicMetadataTypeEnum> = Vec::new();

        let orig_fn_type_ptr = Self::fn_pointer_type(self.context, original_function.get_type());

        // NOTE(review): `enzyme_fn_type` is populated below but never read
        // afterwards in this function — dead-code candidate; confirm.
        let mut enzyme_fn_type: Vec<BasicMetadataTypeEnum> = vec![orig_fn_type_ptr.into()];
        // Wrapper-parameter index at which each original argument starts
        // (shadow arguments shift the later positions).
        let mut start_param_index: Vec<u32> = Vec::new();
        // Wrapper parameters that are pointers (they receive `noalias`).
        let mut ptr_arg_indices: Vec<u32> = Vec::new();
        for (i, arg) in original_function.get_param_iter().enumerate() {
            start_param_index.push(u32::try_from(fn_type.len()).unwrap());
            let arg_type = arg.get_type();
            fn_type.push(arg_type.into());

            enzyme_fn_type.push(self.int_type.into());
            enzyme_fn_type.push(arg.get_type().into());

            if arg_type.is_pointer_type() {
                ptr_arg_indices.push(u32::try_from(i).unwrap());
            }

            match args_type[i] {
                // Duplicated (active) arguments get a shadow parameter.
                CompileGradientArgType::Dup | CompileGradientArgType::DupNoNeed => {
                    fn_type.push(arg.get_type().into());
                    enzyme_fn_type.push(arg.get_type().into());
                }
                CompileGradientArgType::Const => {}
            }
        }
        let void_type = self.context.void_type();
        let fn_type = void_type.fn_type(fn_type.as_slice(), false);
        let function = self.module.add_function(fn_name, fn_type, None);

        // Mark all pointer parameters of the wrapper as noalias.
        let alias_id = Attribute::get_named_enum_kind_id("noalias");
        let noalign = self.context.create_enum_attribute(alias_id, 0);
        for i in ptr_arg_indices {
            function.add_attribute(AttributeLoc::Param(i), noalign);
        }

        let basic_block = self.context.append_basic_block(function, "entry");
        self.fn_value_opt = Some(function);
        self.builder.position_at_end(basic_block);

        // Build the call arguments for the Enzyme-generated function, the
        // per-argument activities, and an Enzyme type tree for each argument.
        let mut enzyme_fn_args: Vec<BasicMetadataValueEnum> = Vec::new();
        let mut input_activity = Vec::new();
        let mut arg_trees = Vec::new();
        for (i, arg) in original_function.get_param_iter().enumerate() {
            let param_index = start_param_index[i];
            let fn_arg = function.get_nth_param(param_index).unwrap();

            // Map the LLVM type to Enzyme's concrete-type classification.
            let concrete_type = match arg.get_type() {
                BasicTypeEnum::PointerType(_) => CConcreteType_DT_Pointer,
                BasicTypeEnum::FloatType(t) => {
                    if t == self.context.f64_type() {
                        CConcreteType_DT_Double
                    } else {
                        panic!("unsupported type")
                    }
                }
                BasicTypeEnum::IntType(_) => CConcreteType_DT_Integer,
                _ => panic!("unsupported type"),
            };
            let new_tree = unsafe {
                EnzymeNewTypeTreeCT(
                    concrete_type,
                    self.context.as_ctx_ref() as *mut LLVMOpaqueContext,
                )
            };
            unsafe { EnzymeTypeTreeOnlyEq(new_tree, -1) };

            // Pointer arguments get an inner Double type merged into their
            // tree. NOTE(review): this assumes every pointer argument points
            // at real-valued data — confirm for integer-pointer parameters.
            if concrete_type == CConcreteType_DT_Pointer {
                let inner_concrete_type = CConcreteType_DT_Double;
                let inner_new_tree = unsafe {
                    EnzymeNewTypeTreeCT(
                        inner_concrete_type,
                        self.context.as_ctx_ref() as *mut LLVMOpaqueContext,
                    )
                };
                unsafe { EnzymeTypeTreeOnlyEq(inner_new_tree, -1) };
                unsafe { EnzymeTypeTreeOnlyEq(inner_new_tree, -1) };
                unsafe { EnzymeMergeTypeTree(new_tree, inner_new_tree) };
            }

            arg_trees.push(new_tree);
            // Forward the primal argument (and shadow for Dup kinds) to the
            // generated function, recording each argument's Enzyme activity.
            match args_type[i] {
                CompileGradientArgType::Dup => {
                    enzyme_fn_args.push(fn_arg.into());

                    let fn_darg = function.get_nth_param(param_index + 1).unwrap();
                    enzyme_fn_args.push(fn_darg.into());

                    input_activity.push(CDIFFE_TYPE_DFT_DUP_ARG);
                }
                CompileGradientArgType::DupNoNeed => {
                    enzyme_fn_args.push(fn_arg.into());

                    let fn_darg = function.get_nth_param(param_index + 1).unwrap();
                    enzyme_fn_args.push(fn_darg.into());

                    input_activity.push(CDIFFE_TYPE_DFT_DUP_NONEED);
                }
                CompileGradientArgType::Const => {
                    enzyme_fn_args.push(fn_arg.into());

                    input_activity.push(CDIFFE_TYPE_DFT_CONSTANT);
                }
            }
        }
        // The return is treated as inactive: no primal return value is kept
        // and no derivative of it is requested.
        let ret_primary_ret = false;
        let diff_ret = false;
        let ret_activity = CDIFFE_TYPE_DFT_CONSTANT;
        let ret_tree = unsafe {
            EnzymeNewTypeTreeCT(
                CConcreteType_DT_Anything,
                self.context.as_ctx_ref() as *mut LLVMOpaqueContext,
            )
        };

        // Create the Enzyme logic/type-analysis objects for this invocation.
        let fnc_opt_base = true;
        let logic_ref: EnzymeLogicRef = unsafe { CreateEnzymeLogic(fnc_opt_base as u8) };

        // No known integer values for any argument.
        let kv_tmp = IntList {
            data: std::ptr::null_mut(),
            size: 0,
        };
        let mut known_values = vec![kv_tmp; input_activity.len()];

        let fn_type_info = CFnTypeInfo {
            Arguments: arg_trees.as_mut_ptr(),
            Return: ret_tree,
            KnownValues: known_values.as_mut_ptr(),
        };

        let type_analysis: EnzymeTypeAnalysisRef =
            unsafe { CreateTypeAnalysis(logic_ref, std::ptr::null_mut(), std::ptr::null_mut(), 0) };

        let mut args_uncacheable = vec![0; arg_trees.len()];

        let enzyme_function = match mode {
            CompileMode::Forward | CompileMode::ForwardSens => unsafe {
                EnzymeCreateForwardDiff(
                    logic_ref,
                    std::ptr::null_mut(),
                    std::ptr::null_mut(),
                    original_function.as_value_ref(),
                    ret_activity,
                    input_activity.as_mut_ptr(),
                    input_activity.len(),
                    type_analysis,
                    ret_primary_ret as u8,
                    CDerivativeMode_DEM_ForwardMode,
                    1,
                    0,
                    1,
                    std::ptr::null_mut(),
                    fn_type_info,
                    args_uncacheable.as_mut_ptr(),
                    args_uncacheable.len(),
                    std::ptr::null_mut(),
                )
            },
            CompileMode::Reverse | CompileMode::ReverseSens => {
                let mut call_enzyme = || unsafe {
                    EnzymeCreatePrimalAndGradient(
                        logic_ref,
                        std::ptr::null_mut(),
                        std::ptr::null_mut(),
                        original_function.as_value_ref(),
                        ret_activity,
                        input_activity.as_mut_ptr(),
                        input_activity.len(),
                        type_analysis,
                        ret_primary_ret as u8,
                        diff_ret as u8,
                        CDerivativeMode_DEM_ReverseModeCombined,
                        0,
                        1,
                        1,
                        std::ptr::null_mut(),
                        0,
                        fn_type_info,
                        args_uncacheable.as_mut_ptr(),
                        args_uncacheable.len(),
                        std::ptr::null_mut(),
                        if self.threaded { 1 } else { 0 },
                    )
                };
                if self.threaded {
                    // Enzyme's call-handler registry is global state —
                    // serialize registration across threads via `my_mutex`.
                    let _lock = my_mutex.lock().unwrap();
                    let barrier_string = CString::new("barrier").unwrap();
                    unsafe {
                        EnzymeRegisterCallHandler(
                            barrier_string.as_ptr(),
                            Some(fwd_handler),
                            Some(rev_handler),
                        )
                    };
                    let ret = call_enzyme();
                    // Deregister so later calls are unaffected by the handlers.
                    unsafe { EnzymeRegisterCallHandler(barrier_string.as_ptr(), None, None) };
                    ret
                } else {
                    call_enzyme()
                }
            }
        };

        // Free the Enzyme analysis objects; the generated function lives on
        // in the module.
        unsafe { FreeEnzymeLogic(logic_ref) };
        unsafe { FreeTypeAnalysis(type_analysis) };
        unsafe { EnzymeFreeTypeTree(ret_tree) };
        for tree in arg_trees {
            unsafe { EnzymeFreeTypeTree(tree) };
        }

        // Wrap the raw LLVMValueRef Enzyme returned and call it from the
        // wrapper body.
        let enzyme_function =
            unsafe { FunctionValue::new(enzyme_function as LLVMValueRef) }.unwrap();
        self.builder
            .build_call(enzyme_function, enzyme_fn_args.as_slice(), "enzyme_call")?;

        self.builder.build_return(None)?;

        if function.verify(true) {
            Ok(function)
        } else {
            function.print_to_stderr();
            enzyme_function.print_to_stderr();
            unsafe {
                function.delete();
            }
            Err(anyhow!("Invalid generated function."))
        }
    }
3158
3159 pub fn compile_get_dims(&mut self, model: &DiscreteModel) -> Result<FunctionValue<'ctx>> {
3160 self.clear();
3161 let fn_type = self.context.void_type().fn_type(
3162 &[
3163 self.int_ptr_type.into(),
3164 self.int_ptr_type.into(),
3165 self.int_ptr_type.into(),
3166 self.int_ptr_type.into(),
3167 self.int_ptr_type.into(),
3168 self.int_ptr_type.into(),
3169 ],
3170 false,
3171 );
3172
3173 let function = self.module.add_function("get_dims", fn_type, None);
3174 let block = self.context.append_basic_block(function, "entry");
3175 let fn_arg_names = &["states", "inputs", "outputs", "data", "stop", "has_mass"];
3176 self.builder.position_at_end(block);
3177
3178 for (i, arg) in function.get_param_iter().enumerate() {
3179 let name = fn_arg_names[i];
3180 let alloca = self.function_arg_alloca(name, arg);
3181 self.insert_param(name, alloca);
3182 }
3183
3184 self.insert_indices();
3185
3186 let number_of_states = model.state().nnz() as u64;
3187 let number_of_inputs = model.inputs().iter().fold(0, |acc, x| acc + x.nnz()) as u64;
3188 let number_of_outputs = match model.out() {
3189 Some(out) => out.nnz() as u64,
3190 None => 0,
3191 };
3192 let number_of_stop = if let Some(stop) = model.stop() {
3193 stop.nnz() as u64
3194 } else {
3195 0
3196 };
3197 let has_mass = match model.lhs().is_some() {
3198 true => 1u64,
3199 false => 0u64,
3200 };
3201 let data_len = self.layout.data().len() as u64;
3202 self.builder.build_store(
3203 *self.get_param("states"),
3204 self.int_type.const_int(number_of_states, false),
3205 )?;
3206 self.builder.build_store(
3207 *self.get_param("inputs"),
3208 self.int_type.const_int(number_of_inputs, false),
3209 )?;
3210 self.builder.build_store(
3211 *self.get_param("outputs"),
3212 self.int_type.const_int(number_of_outputs, false),
3213 )?;
3214 self.builder.build_store(
3215 *self.get_param("data"),
3216 self.int_type.const_int(data_len, false),
3217 )?;
3218 self.builder.build_store(
3219 *self.get_param("stop"),
3220 self.int_type.const_int(number_of_stop, false),
3221 )?;
3222 self.builder.build_store(
3223 *self.get_param("has_mass"),
3224 self.int_type.const_int(has_mass, false),
3225 )?;
3226 self.builder.build_return(None)?;
3227
3228 if function.verify(true) {
3229 Ok(function)
3230 } else {
3231 function.print_to_stderr();
3232 unsafe {
3233 function.delete();
3234 }
3235 Err(anyhow!("Invalid generated function."))
3236 }
3237 }
3238
3239 pub fn compile_get_tensor(
3240 &mut self,
3241 model: &DiscreteModel,
3242 name: &str,
3243 ) -> Result<FunctionValue<'ctx>> {
3244 self.clear();
3245 let real_ptr_ptr_type = Self::pointer_type(self.context, self.real_ptr_type.into());
3246 let fn_type = self.context.void_type().fn_type(
3247 &[
3248 self.real_ptr_type.into(),
3249 real_ptr_ptr_type.into(),
3250 self.int_ptr_type.into(),
3251 ],
3252 false,
3253 );
3254 let function_name = format!("get_tensor_{name}");
3255 let function = self
3256 .module
3257 .add_function(function_name.as_str(), fn_type, None);
3258 let basic_block = self.context.append_basic_block(function, "entry");
3259 self.fn_value_opt = Some(function);
3260
3261 let fn_arg_names = &["data", "tensor_data", "tensor_size"];
3262 self.builder.position_at_end(basic_block);
3263
3264 for (i, arg) in function.get_param_iter().enumerate() {
3265 let name = fn_arg_names[i];
3266 let alloca = self.function_arg_alloca(name, arg);
3267 self.insert_param(name, alloca);
3268 }
3269
3270 self.insert_data(model);
3271 let ptr = self.get_param(name);
3272 let tensor_size = self.layout.get_layout(name).unwrap().nnz() as u64;
3273 let tensor_size_value = self.int_type.const_int(tensor_size, false);
3274 self.builder
3275 .build_store(*self.get_param("tensor_data"), ptr.as_basic_value_enum())?;
3276 self.builder
3277 .build_store(*self.get_param("tensor_size"), tensor_size_value)?;
3278 self.builder.build_return(None)?;
3279
3280 if function.verify(true) {
3281 Ok(function)
3282 } else {
3283 function.print_to_stderr();
3284 unsafe {
3285 function.delete();
3286 }
3287 Err(anyhow!("Invalid generated function."))
3288 }
3289 }
3290
3291 pub fn compile_get_constant(
3292 &mut self,
3293 model: &DiscreteModel,
3294 name: &str,
3295 ) -> Result<FunctionValue<'ctx>> {
3296 self.clear();
3297 let real_ptr_ptr_type = Self::pointer_type(self.context, self.real_ptr_type.into());
3298 let fn_type = self
3299 .context
3300 .void_type()
3301 .fn_type(&[real_ptr_ptr_type.into(), self.int_ptr_type.into()], false);
3302 let function_name = format!("get_constant_{name}");
3303 let function = self
3304 .module
3305 .add_function(function_name.as_str(), fn_type, None);
3306 let basic_block = self.context.append_basic_block(function, "entry");
3307 self.fn_value_opt = Some(function);
3308
3309 let fn_arg_names = &["tensor_data", "tensor_size"];
3310 self.builder.position_at_end(basic_block);
3311
3312 for (i, arg) in function.get_param_iter().enumerate() {
3313 let name = fn_arg_names[i];
3314 let alloca = self.function_arg_alloca(name, arg);
3315 self.insert_param(name, alloca);
3316 }
3317
3318 self.insert_constants(model);
3319 let ptr = self.get_param(name);
3320 let tensor_size = self.layout.get_layout(name).unwrap().nnz() as u64;
3321 let tensor_size_value = self.int_type.const_int(tensor_size, false);
3322 self.builder
3323 .build_store(*self.get_param("tensor_data"), ptr.as_basic_value_enum())?;
3324 self.builder
3325 .build_store(*self.get_param("tensor_size"), tensor_size_value)?;
3326 self.builder.build_return(None)?;
3327
3328 if function.verify(true) {
3329 Ok(function)
3330 } else {
3331 function.print_to_stderr();
3332 unsafe {
3333 function.delete();
3334 }
3335 Err(anyhow!("Invalid generated function."))
3336 }
3337 }
3338
3339 pub fn compile_inputs(
3340 &mut self,
3341 model: &DiscreteModel,
3342 is_get: bool,
3343 ) -> Result<FunctionValue<'ctx>> {
3344 self.clear();
3345 let void_type = self.context.void_type();
3346 let fn_type = void_type.fn_type(
3347 &[self.real_ptr_type.into(), self.real_ptr_type.into()],
3348 false,
3349 );
3350 let function_name = if is_get { "get_inputs" } else { "set_inputs" };
3351 let function = self.module.add_function(function_name, fn_type, None);
3352 let mut block = self.context.append_basic_block(function, "entry");
3353 self.fn_value_opt = Some(function);
3354
3355 let fn_arg_names = &["inputs", "data"];
3356 self.builder.position_at_end(block);
3357
3358 for (i, arg) in function.get_param_iter().enumerate() {
3359 let name = fn_arg_names[i];
3360 let alloca = self.function_arg_alloca(name, arg);
3361 self.insert_param(name, alloca);
3362 }
3363
3364 let mut inputs_index = 0usize;
3365 for input in model.inputs() {
3366 let name = format!("input_{}", input.name());
3367 self.insert_tensor(input, false);
3368 let ptr = self.get_var(input);
3369 let inputs_start_index = self.int_type.const_int(inputs_index as u64, false);
3371 let start_index = self.int_type.const_int(0, false);
3372 let end_index = self
3373 .int_type
3374 .const_int(input.nnz().try_into().unwrap(), false);
3375
3376 let input_block = self.context.append_basic_block(function, name.as_str());
3377 self.builder.build_unconditional_branch(input_block)?;
3378 self.builder.position_at_end(input_block);
3379 let index = self.builder.build_phi(self.int_type, "i")?;
3380 index.add_incoming(&[(&start_index, block)]);
3381
3382 let curr_input_index = index.as_basic_value().into_int_value();
3384 let input_ptr = Self::get_ptr_to_index(
3385 &self.builder,
3386 self.real_type,
3387 ptr,
3388 curr_input_index,
3389 name.as_str(),
3390 );
3391 let curr_inputs_index =
3392 self.builder
3393 .build_int_add(inputs_start_index, curr_input_index, name.as_str())?;
3394 let inputs_ptr = Self::get_ptr_to_index(
3395 &self.builder,
3396 self.real_type,
3397 self.get_param("inputs"),
3398 curr_inputs_index,
3399 name.as_str(),
3400 );
3401 if is_get {
3402 let input_value = self
3403 .build_load(self.real_type, input_ptr, name.as_str())?
3404 .into_float_value();
3405 self.builder.build_store(inputs_ptr, input_value)?;
3406 } else {
3407 let input_value = self
3408 .build_load(self.real_type, inputs_ptr, name.as_str())?
3409 .into_float_value();
3410 self.builder.build_store(input_ptr, input_value)?;
3411 }
3412
3413 let one = self.int_type.const_int(1, false);
3415 let next_index = self
3416 .builder
3417 .build_int_add(curr_input_index, one, name.as_str())?;
3418 index.add_incoming(&[(&next_index, input_block)]);
3419
3420 let loop_while = self.builder.build_int_compare(
3422 IntPredicate::ULT,
3423 next_index,
3424 end_index,
3425 name.as_str(),
3426 )?;
3427 let post_block = self.context.append_basic_block(function, name.as_str());
3428 self.builder
3429 .build_conditional_branch(loop_while, input_block, post_block)?;
3430 self.builder.position_at_end(post_block);
3431
3432 block = post_block;
3434 inputs_index += input.nnz();
3435 }
3436 self.builder.build_return(None)?;
3437
3438 if function.verify(true) {
3439 Ok(function)
3440 } else {
3441 function.print_to_stderr();
3442 unsafe {
3443 function.delete();
3444 }
3445 Err(anyhow!("Invalid generated function."))
3446 }
3447 }
3448
3449 pub fn compile_set_id(&mut self, model: &DiscreteModel) -> Result<FunctionValue<'ctx>> {
3450 self.clear();
3451 let void_type = self.context.void_type();
3452 let fn_type = void_type.fn_type(&[self.real_ptr_type.into()], false);
3453 let function = self.module.add_function("set_id", fn_type, None);
3454 let mut block = self.context.append_basic_block(function, "entry");
3455
3456 let fn_arg_names = &["id"];
3457 self.builder.position_at_end(block);
3458
3459 for (i, arg) in function.get_param_iter().enumerate() {
3460 let name = fn_arg_names[i];
3461 let alloca = self.function_arg_alloca(name, arg);
3462 self.insert_param(name, alloca);
3463 }
3464
3465 let mut id_index = 0usize;
3466 for (blk, is_algebraic) in zip(model.state().elmts(), model.is_algebraic()) {
3467 let name = blk.name().unwrap_or("unknown");
3468 let id_start_index = self.int_type.const_int(id_index as u64, false);
3470 let blk_start_index = self.int_type.const_int(0, false);
3471 let blk_end_index = self
3472 .int_type
3473 .const_int(blk.nnz().try_into().unwrap(), false);
3474
3475 let blk_block = self.context.append_basic_block(function, name);
3476 self.builder.build_unconditional_branch(blk_block)?;
3477 self.builder.position_at_end(blk_block);
3478 let index = self.builder.build_phi(self.int_type, "i")?;
3479 index.add_incoming(&[(&blk_start_index, block)]);
3480
3481 let curr_blk_index = index.as_basic_value().into_int_value();
3483 let curr_id_index = self
3484 .builder
3485 .build_int_add(id_start_index, curr_blk_index, name)?;
3486 let id_ptr = Self::get_ptr_to_index(
3487 &self.builder,
3488 self.real_type,
3489 self.get_param("id"),
3490 curr_id_index,
3491 name,
3492 );
3493 let is_algebraic_float = if *is_algebraic {
3494 0.0 as RealType
3495 } else {
3496 1.0 as RealType
3497 };
3498 let is_algebraic_value = self.real_type.const_float(is_algebraic_float);
3499 self.builder.build_store(id_ptr, is_algebraic_value)?;
3500
3501 let one = self.int_type.const_int(1, false);
3503 let next_index = self.builder.build_int_add(curr_blk_index, one, name)?;
3504 index.add_incoming(&[(&next_index, blk_block)]);
3505
3506 let loop_while = self.builder.build_int_compare(
3508 IntPredicate::ULT,
3509 next_index,
3510 blk_end_index,
3511 name,
3512 )?;
3513 let post_block = self.context.append_basic_block(function, name);
3514 self.builder
3515 .build_conditional_branch(loop_while, blk_block, post_block)?;
3516 self.builder.position_at_end(post_block);
3517
3518 block = post_block;
3520 id_index += blk.nnz();
3521 }
3522 self.builder.build_return(None)?;
3523
3524 if function.verify(true) {
3525 Ok(function)
3526 } else {
3527 function.print_to_stderr();
3528 unsafe {
3529 function.delete();
3530 }
3531 Err(anyhow!("Invalid generated function."))
3532 }
3533 }
3534
    /// Returns a shared reference to the LLVM module that holds all of the
    /// functions generated by the `compile_*` methods on this object.
    pub fn module(&self) -> &Module<'ctx> {
        &self.module
    }
3538}