1use aliasable::boxed::AliasableBox;
2use anyhow::{anyhow, Result};
3use core::panic;
4use inkwell::attributes::{Attribute, AttributeLoc};
5use inkwell::basic_block::BasicBlock;
6use inkwell::builder::Builder;
7use inkwell::context::{AsContextRef, Context};
8use inkwell::debug_info::AsDIScope;
9use inkwell::debug_info::{DICompileUnit, DIFlags, DIFlagsConstants, DebugInfoBuilder};
10use inkwell::execution_engine::ExecutionEngine;
11use inkwell::intrinsics::Intrinsic;
12use inkwell::module::{Linkage, Module};
13use inkwell::passes::PassBuilderOptions;
14use inkwell::targets::{FileType, InitializationConfig, Target, TargetMachine, TargetTriple};
15use inkwell::types::{
16 BasicMetadataTypeEnum, BasicType, BasicTypeEnum, FloatType, FunctionType, IntType, PointerType,
17};
18use inkwell::values::{
19 AsValueRef, BasicMetadataValueEnum, BasicValue, BasicValueEnum, CallSiteValue, FloatValue,
20 FunctionValue, GlobalValue, IntValue, PointerValue,
21};
22use inkwell::{
23 AddressSpace, AtomicOrdering, AtomicRMWBinOp, FloatPredicate, GlobalVisibility, IntPredicate,
24 OptimizationLevel,
25};
26use llvm_sys::core::{
27 LLVMBuildCall2, LLVMGetArgOperand, LLVMGetBasicBlockParent, LLVMGetGlobalParent,
28 LLVMGetInstructionParent, LLVMGetNamedFunction, LLVMGlobalGetValueType, LLVMIsMultithreaded,
29};
30use llvm_sys::prelude::{LLVMBuilderRef, LLVMValueRef};
31use pest::Span;
32use std::collections::HashMap;
33use std::ffi::CString;
34use std::iter::zip;
35use std::pin::Pin;
36use target_lexicon::Triple;
37
38use crate::ast::{Ast, AstKind};
39use crate::discretise::{DiscreteModel, Tensor, TensorBlock};
40use crate::enzyme::{
41 CConcreteType_DT_Anything, CConcreteType_DT_Double, CConcreteType_DT_Float,
42 CConcreteType_DT_Integer, CConcreteType_DT_Pointer, CDerivativeMode_DEM_ForwardMode,
43 CDerivativeMode_DEM_ReverseModeCombined, CFnTypeInfo, CreateEnzymeLogic, CreateTypeAnalysis,
44 DiffeGradientUtils, EnzymeCreateForwardDiff, EnzymeCreatePrimalAndGradient, EnzymeFreeTypeTree,
45 EnzymeGradientUtilsNewFromOriginal, EnzymeLogicRef, EnzymeMergeTypeTree, EnzymeNewTypeTreeCT,
46 EnzymeRegisterCallHandler, EnzymeTypeAnalysisRef, EnzymeTypeTreeOnlyEq, FreeEnzymeLogic,
47 FreeTypeAnalysis, GradientUtils, IntList, LLVMOpaqueContext, CDIFFE_TYPE_DFT_CONSTANT,
48 CDIFFE_TYPE_DFT_DUP_ARG, CDIFFE_TYPE_DFT_DUP_NONEED,
49};
50use crate::execution::compiler::CompilerOptions;
51use crate::execution::module::{
52 CodegenModule, CodegenModuleCompile, CodegenModuleEmit, CodegenModuleJit,
53};
54use crate::execution::scalar::RealType;
55use crate::execution::{DataLayout, Translation, TranslationFrom, TranslationTo};
56use lazy_static::lazy_static;
57use std::sync::Mutex;
58
// Lazily-initialised, process-wide mutex wrapping an `i32` that starts at 0.
// NOTE(review): no lock site is visible in this part of the file; presumably it is
// used to serialise access to LLVM/Enzyme state that is not thread-safe — confirm
// against the places that lock it.
lazy_static! {
    static ref my_mutex: Mutex<i32> = Mutex::new(0i32);
}
62
/// Pairs an LLVM [`Context`] with the [`CodeGen`] that borrows from it.
///
/// `codegen` is stored with a `'static` lifetime although it really borrows `context`
/// (see the `transmute` in `LlvmModule::new`), so the two must never be separated.
/// Field order matters: Rust drops fields in declaration order, so `codegen` is
/// dropped before the `context` it refers to.
struct ImmovableLlvmModule {
    /// Code generator borrowing `context`; `None` only while `LlvmModule::new` is
    /// still building the value.
    codegen: Option<CodeGen<'static>>,
    /// Boxed LLVM context that `codegen` points into.
    context: AliasableBox<Context>,
    /// Makes the type `!Unpin` so that it is only handled through `Pin`.
    _pin: std::marker::PhantomPinned,
}
71
/// An LLVM-backed code generation module: the pinned context/code generator pair plus
/// the target machine used for running optimisation passes and emitting object code.
pub struct LlvmModule {
    /// Pinned, heap-allocated context and code generator.
    inner: Pin<Box<ImmovableLlvmModule>>,
    /// Target machine created in `LlvmModule::new` for the requested (or host) triple.
    machine: TargetMachine,
}
76
// SAFETY: these impls are promises the compiler cannot check.
// NOTE(review): they presumably rely on every `LlvmModule` owning its own private LLVM
// context and on callers never touching the same module from two threads at once —
// confirm before sharing a module across threads.
unsafe impl Send for LlvmModule {}
unsafe impl Sync for LlvmModule {}
79
80impl LlvmModule {
81 fn new(
82 triple: Option<Triple>,
83 model: &DiscreteModel,
84 threaded: bool,
85 real_type: RealType,
86 debug: bool,
87 constants_use_jit: bool,
88 ) -> Result<Self> {
89 let initialization_config = &InitializationConfig::default();
90 Target::initialize_all(initialization_config);
91 let host_triple = Triple::host();
92 let (triple_str, native) = match triple {
93 Some(ref triple) => (triple.to_string(), false),
94 None => (host_triple.to_string(), true),
95 };
96 let triple = TargetTriple::create(triple_str.as_str());
97 let target = Target::from_triple(&triple).unwrap();
98 let cpu = if native {
99 TargetMachine::get_host_cpu_name().to_string()
100 } else {
101 "generic".to_string()
102 };
103 let features = if native {
104 TargetMachine::get_host_cpu_features().to_string()
105 } else {
106 "".to_string()
107 };
108 let machine = target
109 .create_target_machine(
110 &triple,
111 cpu.as_str(),
112 features.as_str(),
113 inkwell::OptimizationLevel::Aggressive,
114 inkwell::targets::RelocMode::PIC,
115 inkwell::targets::CodeModel::Default,
116 )
117 .unwrap();
118
119 let context = AliasableBox::from_unique(Box::new(Context::create()));
120 let mut pinned = Self {
121 inner: Box::pin(ImmovableLlvmModule {
122 codegen: None,
123 context,
124 _pin: std::marker::PhantomPinned,
125 }),
126 machine,
127 };
128
129 let context_ref = pinned.inner.context.as_ref();
130 let real_type_llvm = match real_type {
131 RealType::F32 => context_ref.f32_type(),
132 RealType::F64 => context_ref.f64_type(),
133 };
134 let int_type_llvm = context_ref.i32_type();
135 let ptr_size_bits = pinned
136 .machine
137 .get_target_data()
138 .get_bit_size(&context_ref.ptr_type(AddressSpace::default()));
139 let real_size_bits = pinned
140 .machine
141 .get_target_data()
142 .get_bit_size(&real_type_llvm);
143 let int_size_bits = pinned
144 .machine
145 .get_target_data()
146 .get_bit_size(&int_type_llvm);
147 let codegen = CodeGen::new(
148 model,
149 context_ref,
150 real_type,
151 real_type_llvm,
152 int_type_llvm,
153 threaded,
154 ptr_size_bits,
155 real_size_bits,
156 int_size_bits,
157 debug,
158 constants_use_jit,
159 )?;
160 let codegen = unsafe { std::mem::transmute::<CodeGen<'_>, CodeGen<'static>>(codegen) };
161 unsafe { pinned.inner.as_mut().get_unchecked_mut().codegen = Some(codegen) };
162 Ok(pinned)
163 }
164
165 fn pre_autodiff_optimisation(&mut self) -> Result<()> {
166 let pass_options = PassBuilderOptions::create();
175 pass_options.set_loop_vectorization(false);
179 pass_options.set_loop_slp_vectorization(false);
180 pass_options.set_loop_unrolling(false);
181 let passes = "annotation2metadata,forceattrs,inferattrs,coro-early,function<eager-inv>(lower-expect,simplifycfg<bonus-inst-threshold=1;no-forward-switch-cond;no-switch-range-to-icmp;no-switch-to-lookup;keep-loops;no-hoist-common-insts;no-sink-common-insts>,early-cse<>),openmp-opt,ipsccp,called-value-propagation,globalopt,function(mem2reg),function<eager-inv>(instcombine,simplifycfg<bonus-inst-threshold=1;no-forward-switch-cond;switch-range-to-icmp;no-switch-to-lookup;keep-loops;no-hoist-common-insts;no-sink-common-insts>),require<globals-aa>,function(invalidate<aa>),require<profile-summary>,cgscc(devirt<4>(inline<only-mandatory>,inline,function-attrs,openmp-opt-cgscc,function<eager-inv>(early-cse<memssa>,speculative-execution,jump-threading,correlated-propagation,simplifycfg<bonus-inst-threshold=1;no-forward-switch-cond;switch-range-to-icmp;no-switch-to-lookup;keep-loops;no-hoist-common-insts;no-sink-common-insts>,instcombine,libcalls-shrinkwrap,tailcallelim,simplifycfg<bonus-inst-threshold=1;no-forward-switch-cond;switch-range-to-icmp;no-switch-to-lookup;keep-loops;no-hoist-common-insts;no-sink-common-insts>,reassociate,require<opt-remark-emit>,loop-mssa(loop-instsimplify,loop-simplifycfg,licm<no-allowspeculation>,loop-rotate,licm<allowspeculation>,simple-loop-unswitch<no-nontrivial;trivial>),simplifycfg<bonus-inst-threshold=1;no-forward-switch-cond;switch-range-to-icmp;no-switch-to-lookup;keep-loops;no-hoist-common-insts;no-sink-common-insts>,instcombine,loop(loop-idiom,indvars,loop-deletion),vector-combine,mldst-motion<no-split-footer-bb>,gvn<>,sccp,bdce,instcombine,jump-threading,correlated-propagation,adce,memcpyopt,dse,loop-mssa(licm<allowspeculation>),coro-elide,simplifycfg<bonus-inst-threshold=1;no-forward-switch-cond;switch-range-to-icmp;no-switch-to-lookup;keep-loops;hoist-common-insts;sink-common-insts>,instcombine),coro-split)),deadargelim,coro-cleanup,globalopt,globaldce,elim-avail-extern,rpo-function-attrs,recompute-globalsaa,function<eager-inv>(fl
oat2int,lower-constant-intrinsics,loop(loop-rotate,loop-deletion),loop-distribute,inject-tli-mappings,loop-load-elim,instcombine,simplifycfg<bonus-inst-threshold=1;forward-switch-cond;switch-range-to-icmp;switch-to-lookup;no-keep-loops;hoist-common-insts;sink-common-insts>,vector-combine,instcombine,transform-warning,instcombine,require<opt-remark-emit>,loop-mssa(licm<allowspeculation>),alignment-from-assumptions,loop-sink,instsimplify,div-rem-pairs,tailcallelim,simplifycfg<bonus-inst-threshold=1;no-forward-switch-cond;switch-range-to-icmp;no-switch-to-lookup;keep-loops;no-hoist-common-insts;no-sink-common-insts>),globaldce,constmerge,cg-profile,rel-lookup-table-converter,function(annotation-remarks),verify";
194 let (codegen, machine) = self.codegen_and_machine_mut();
195 codegen
196 .module()
197 .run_passes(passes, machine, pass_options)
198 .map_err(|e| anyhow!("Failed to run passes: {:?}", e))
199
200 }
206
207 fn post_autodiff_optimisation(&mut self) -> Result<()> {
208 if let Some(barrier_func) = self.codegen_mut().module().get_function("barrier") {
210 let nolinline_kind_id = Attribute::get_named_enum_kind_id("noinline");
211 barrier_func.remove_enum_attribute(AttributeLoc::Function, nolinline_kind_id);
212 }
213
214 for f in self.codegen_mut().module.get_functions() {
216 if f.get_name().to_str().unwrap().starts_with("preprocess_") {
217 unsafe { f.delete() };
218 }
219 }
220
221 let passes = "default<O3>";
227 let (codegen, machine) = self.codegen_and_machine_mut();
228 codegen
229 .module()
230 .run_passes(passes, machine, PassBuilderOptions::create())
231 .map_err(|e| anyhow!("Failed to run passes: {:?}", e))?;
232
233 Ok(())
239 }
240
241 pub fn print(&self) {
242 self.codegen().module().print_to_stderr();
243 }
244 fn codegen_mut(&mut self) -> &mut CodeGen<'static> {
245 unsafe {
246 self.inner
247 .as_mut()
248 .get_unchecked_mut()
249 .codegen
250 .as_mut()
251 .unwrap()
252 }
253 }
254 fn codegen_and_machine_mut(&mut self) -> (&mut CodeGen<'static>, &TargetMachine) {
255 (
256 unsafe {
257 self.inner
258 .as_mut()
259 .get_unchecked_mut()
260 .codegen
261 .as_mut()
262 .unwrap()
263 },
264 &self.machine,
265 )
266 }
267
268 fn codegen(&self) -> &CodeGen<'static> {
269 self.inner.as_ref().get_ref().codegen.as_ref().unwrap()
270 }
271
272 pub fn to_dynamic_library(self, output_path: impl Into<std::path::PathBuf>) -> Result<()> {
273 use std::fs;
274 use std::process::Command;
275
276 let output_path = output_path.into();
277 let object_buffer = self.to_object()?;
278
279 let temp_dir = std::env::temp_dir();
281 let obj_path = temp_dir.join("diffsl_temp_object.o");
282 fs::write(&obj_path, object_buffer)
283 .map_err(|e| anyhow!("Failed to write temporary object file: {}", e))?;
284
285 let lld = option_env!("DIFFSL_LLVM_LLD")
286 .ok_or_else(|| anyhow!("DIFFSL_LLVM_LLD not set by build script"))?;
287 let linker_name = std::path::Path::new(lld)
288 .file_name()
289 .and_then(|name| name.to_str())
290 .unwrap_or(lld);
291 let is_clang_driver = linker_name.starts_with("clang");
292
293 let mut command = Command::new(lld);
294 if cfg!(target_os = "windows") {
295 if is_clang_driver {
296 command.arg("-shared");
297 command.arg("-o");
298 command.arg(&output_path);
299 command.arg(&obj_path);
300 } else {
301 command.arg("-flavor").arg("link");
302 command.arg("/DLL");
303 command.arg(format!("/OUT:{}", output_path.display()));
304 command.arg(&obj_path);
305 }
306 } else if cfg!(target_os = "macos") {
307 if !lld.ends_with("ld64.lld") && !is_clang_driver {
308 command.arg("-flavor").arg("darwin");
309 }
310 let arch = if cfg!(target_arch = "aarch64") {
311 "arm64"
312 } else if cfg!(target_arch = "x86_64") {
313 "x86_64"
314 } else {
315 return Err(anyhow!("Unsupported macOS architecture for lld invocation"));
316 };
317 let deployment_target =
318 std::env::var("MACOSX_DEPLOYMENT_TARGET").unwrap_or_else(|_| "11.0".to_string());
319 if is_clang_driver {
320 command.arg("-dynamiclib");
321 command.arg("-arch");
322 command.arg(arch);
323 command.arg(format!("-mmacosx-version-min={deployment_target}"));
324 command.arg("-o");
325 command.arg(&output_path);
326 command.arg(&obj_path);
327 } else {
328 command.arg("-arch");
329 command.arg(arch);
330 command.arg("-platform_version");
331 command.arg("macos");
332 command.arg(&deployment_target);
333 command.arg(&deployment_target);
334 command.arg("-dylib");
335 command.arg("-o");
336 command.arg(&output_path);
337 command.arg(&obj_path);
338 }
339 } else {
340 if is_clang_driver {
341 command.arg("-shared");
342 command.arg("-o");
343 command.arg(&output_path);
344 command.arg(&obj_path);
345 } else {
346 command.arg("-flavor").arg("gnu");
347 command.arg("-shared");
348 command.arg("-o");
349 command.arg(&output_path);
350 command.arg(&obj_path);
351 }
352 }
353
354 let status = command
355 .status()
356 .map_err(|e| anyhow!("Failed to invoke lld: {}", e))?;
357 if !status.success() {
358 return Err(anyhow!(
359 "Dynamic library link failed with status: {}",
360 status
361 ));
362 }
363
364 let _ = fs::remove_file(&obj_path);
365 Ok(())
366 }
367}
368
// Marks `LlvmModule` as a code generation backend; the impl body is empty, so no items
// are overridden here.
impl CodegenModule for LlvmModule {}
370
371impl CodegenModuleCompile for LlvmModule {
372 fn from_discrete_model(
373 model: &DiscreteModel,
374 options: CompilerOptions,
375 triple: Option<Triple>,
376 real_type: RealType,
377 code: Option<&str>,
378 ) -> Result<Self> {
379 let thread_dim = options.mode.thread_dim(model.state().nnz());
380 let threaded = thread_dim > 1;
381 if (unsafe { LLVMIsMultithreaded() } <= 0) {
382 return Err(anyhow!(
383 "LLVM is not compiled with multithreading support, but this codegen module requires it."
384 ));
385 }
386
387 let constants_use_jit = options.constants_use_jit;
388 let mut module = Self::new(
389 triple,
390 model,
391 threaded,
392 real_type,
393 options.debug,
394 constants_use_jit,
395 )?;
396
397 let set_u0 = module.codegen_mut().compile_set_u0(model, code)?;
398 let calc_stop = module.codegen_mut().compile_calc_stop(model, false, code)?;
399 let calc_stop_full = module.codegen_mut().compile_calc_stop(model, true, code)?;
400 let reset = module.codegen_mut().compile_reset(model, false, code)?;
401 let reset_full = module.codegen_mut().compile_reset(model, true, code)?;
402 let rhs = module.codegen_mut().compile_rhs(model, false, code)?;
403 let rhs_full = module.codegen_mut().compile_rhs(model, true, code)?;
404 let mass = module.codegen_mut().compile_mass(model, code)?;
405 let calc_out = module.codegen_mut().compile_calc_out(model, false, code)?;
406 let calc_out_full = module.codegen_mut().compile_calc_out(model, true, code)?;
407 let _set_id = module.codegen_mut().compile_set_id(model)?;
408 let _get_dims = module.codegen_mut().compile_get_dims(model)?;
409 let set_inputs = module.codegen_mut().compile_inputs(model, false)?;
410 let _get_inputs = module.codegen_mut().compile_inputs(model, true)?;
411 let _set_constants =
412 module
413 .codegen_mut()
414 .compile_set_constants(model, code, constants_use_jit)?;
415 let tensor_info = module
416 .codegen()
417 .layout
418 .tensors()
419 .map(|(name, is_constant)| (name.to_string(), is_constant))
420 .collect::<Vec<_>>();
421 for (tensor, is_constant) in tensor_info {
422 if is_constant {
423 module
424 .codegen_mut()
425 .compile_get_constant(model, tensor.as_str())?;
426 } else {
427 module
428 .codegen_mut()
429 .compile_get_tensor(model, tensor.as_str())?;
430 }
431 }
432
433 module.pre_autodiff_optimisation()?;
434
435 module.codegen_mut().compile_gradient(
436 set_u0,
437 &[
438 CompileGradientArgType::DupNoNeed,
439 CompileGradientArgType::DupNoNeed,
440 CompileGradientArgType::Const,
441 CompileGradientArgType::Const,
442 ],
443 CompileMode::Forward,
444 "set_u0_grad",
445 )?;
446
447 module.codegen_mut().compile_gradient(
448 rhs,
449 &[
450 CompileGradientArgType::Const,
451 CompileGradientArgType::DupNoNeed,
452 CompileGradientArgType::DupNoNeed,
453 CompileGradientArgType::DupNoNeed,
454 CompileGradientArgType::Const,
455 CompileGradientArgType::Const,
456 ],
457 CompileMode::Forward,
458 "rhs_grad",
459 )?;
460
461 module.codegen_mut().compile_gradient(
462 reset,
463 &[
464 CompileGradientArgType::Const,
465 CompileGradientArgType::DupNoNeed,
466 CompileGradientArgType::DupNoNeed,
467 CompileGradientArgType::DupNoNeed,
468 CompileGradientArgType::Const,
469 CompileGradientArgType::Const,
470 ],
471 CompileMode::Forward,
472 "reset_grad",
473 )?;
474
475 module.codegen_mut().compile_gradient(
476 calc_stop,
477 &[
478 CompileGradientArgType::Const,
479 CompileGradientArgType::DupNoNeed,
480 CompileGradientArgType::DupNoNeed,
481 CompileGradientArgType::DupNoNeed,
482 CompileGradientArgType::Const,
483 CompileGradientArgType::Const,
484 ],
485 CompileMode::Forward,
486 "calc_stop_grad",
487 )?;
488
489 module.codegen_mut().compile_gradient(
490 calc_out,
491 &[
492 CompileGradientArgType::Const,
493 CompileGradientArgType::DupNoNeed,
494 CompileGradientArgType::DupNoNeed,
495 CompileGradientArgType::DupNoNeed,
496 CompileGradientArgType::Const,
497 CompileGradientArgType::Const,
498 ],
499 CompileMode::Forward,
500 "calc_out_grad",
501 )?;
502 module.codegen_mut().compile_gradient(
503 set_inputs,
504 &[
505 CompileGradientArgType::DupNoNeed,
506 CompileGradientArgType::DupNoNeed,
507 CompileGradientArgType::Const,
508 ],
509 CompileMode::Forward,
510 "set_inputs_grad",
511 )?;
512
513 module.codegen_mut().compile_gradient(
514 set_u0,
515 &[
516 CompileGradientArgType::DupNoNeed,
517 CompileGradientArgType::DupNoNeed,
518 CompileGradientArgType::Const,
519 CompileGradientArgType::Const,
520 ],
521 CompileMode::Reverse,
522 "set_u0_rgrad",
523 )?;
524
525 module.codegen_mut().compile_gradient(
526 mass,
527 &[
528 CompileGradientArgType::Const,
529 CompileGradientArgType::DupNoNeed,
530 CompileGradientArgType::DupNoNeed,
531 CompileGradientArgType::DupNoNeed,
532 CompileGradientArgType::Const,
533 CompileGradientArgType::Const,
534 ],
535 CompileMode::Reverse,
536 "mass_rgrad",
537 )?;
538
539 module.codegen_mut().compile_gradient(
540 rhs,
541 &[
542 CompileGradientArgType::Const,
543 CompileGradientArgType::DupNoNeed,
544 CompileGradientArgType::DupNoNeed,
545 CompileGradientArgType::DupNoNeed,
546 CompileGradientArgType::Const,
547 CompileGradientArgType::Const,
548 ],
549 CompileMode::Reverse,
550 "rhs_rgrad",
551 )?;
552 module.codegen_mut().compile_gradient(
553 reset,
554 &[
555 CompileGradientArgType::Const,
556 CompileGradientArgType::DupNoNeed,
557 CompileGradientArgType::DupNoNeed,
558 CompileGradientArgType::DupNoNeed,
559 CompileGradientArgType::Const,
560 CompileGradientArgType::Const,
561 ],
562 CompileMode::Reverse,
563 "reset_rgrad",
564 )?;
565 module.codegen_mut().compile_gradient(
566 calc_stop,
567 &[
568 CompileGradientArgType::Const,
569 CompileGradientArgType::DupNoNeed,
570 CompileGradientArgType::DupNoNeed,
571 CompileGradientArgType::DupNoNeed,
572 CompileGradientArgType::Const,
573 CompileGradientArgType::Const,
574 ],
575 CompileMode::Reverse,
576 "calc_stop_rgrad",
577 )?;
578 module.codegen_mut().compile_gradient(
579 calc_out,
580 &[
581 CompileGradientArgType::Const,
582 CompileGradientArgType::DupNoNeed,
583 CompileGradientArgType::DupNoNeed,
584 CompileGradientArgType::DupNoNeed,
585 CompileGradientArgType::Const,
586 CompileGradientArgType::Const,
587 ],
588 CompileMode::Reverse,
589 "calc_out_rgrad",
590 )?;
591
592 module.codegen_mut().compile_gradient(
593 set_inputs,
594 &[
595 CompileGradientArgType::DupNoNeed,
596 CompileGradientArgType::DupNoNeed,
597 CompileGradientArgType::Const,
598 ],
599 CompileMode::Reverse,
600 "set_inputs_rgrad",
601 )?;
602
603 module.codegen_mut().compile_gradient(
604 rhs_full,
605 &[
606 CompileGradientArgType::Const,
607 CompileGradientArgType::Const,
608 CompileGradientArgType::DupNoNeed,
609 CompileGradientArgType::DupNoNeed,
610 CompileGradientArgType::Const,
611 CompileGradientArgType::Const,
612 ],
613 CompileMode::ForwardSens,
614 "rhs_sgrad",
615 )?;
616
617 module.codegen_mut().compile_gradient(
618 reset_full,
619 &[
620 CompileGradientArgType::Const,
621 CompileGradientArgType::Const,
622 CompileGradientArgType::DupNoNeed,
623 CompileGradientArgType::DupNoNeed,
624 CompileGradientArgType::Const,
625 CompileGradientArgType::Const,
626 ],
627 CompileMode::ForwardSens,
628 "reset_sgrad",
629 )?;
630
631 module.codegen_mut().compile_gradient(
632 calc_stop_full,
633 &[
634 CompileGradientArgType::Const,
635 CompileGradientArgType::Const,
636 CompileGradientArgType::DupNoNeed,
637 CompileGradientArgType::DupNoNeed,
638 CompileGradientArgType::Const,
639 CompileGradientArgType::Const,
640 ],
641 CompileMode::ForwardSens,
642 "calc_stop_sgrad",
643 )?;
644
645 module.codegen_mut().compile_gradient(
646 set_u0,
647 &[
648 CompileGradientArgType::DupNoNeed,
649 CompileGradientArgType::DupNoNeed,
650 CompileGradientArgType::Const,
651 CompileGradientArgType::Const,
652 ],
653 CompileMode::ForwardSens,
654 "set_u0_sgrad",
655 )?;
656
657 module.codegen_mut().compile_gradient(
658 calc_out_full,
659 &[
660 CompileGradientArgType::Const,
661 CompileGradientArgType::Const,
662 CompileGradientArgType::DupNoNeed,
663 CompileGradientArgType::DupNoNeed,
664 CompileGradientArgType::Const,
665 CompileGradientArgType::Const,
666 ],
667 CompileMode::ForwardSens,
668 "calc_out_sgrad",
669 )?;
670 module.codegen_mut().compile_gradient(
671 calc_out_full,
672 &[
673 CompileGradientArgType::Const,
674 CompileGradientArgType::Const,
675 CompileGradientArgType::DupNoNeed,
676 CompileGradientArgType::DupNoNeed,
677 CompileGradientArgType::Const,
678 CompileGradientArgType::Const,
679 ],
680 CompileMode::ReverseSens,
681 "calc_out_srgrad",
682 )?;
683
684 module.codegen_mut().compile_gradient(
685 rhs_full,
686 &[
687 CompileGradientArgType::Const,
688 CompileGradientArgType::Const,
689 CompileGradientArgType::DupNoNeed,
690 CompileGradientArgType::DupNoNeed,
691 CompileGradientArgType::Const,
692 CompileGradientArgType::Const,
693 ],
694 CompileMode::ReverseSens,
695 "rhs_srgrad",
696 )?;
697
698 module.codegen_mut().compile_gradient(
699 reset_full,
700 &[
701 CompileGradientArgType::Const,
702 CompileGradientArgType::Const,
703 CompileGradientArgType::DupNoNeed,
704 CompileGradientArgType::DupNoNeed,
705 CompileGradientArgType::Const,
706 CompileGradientArgType::Const,
707 ],
708 CompileMode::ReverseSens,
709 "reset_srgrad",
710 )?;
711
712 module.codegen_mut().compile_gradient(
713 calc_stop_full,
714 &[
715 CompileGradientArgType::Const,
716 CompileGradientArgType::Const,
717 CompileGradientArgType::DupNoNeed,
718 CompileGradientArgType::DupNoNeed,
719 CompileGradientArgType::Const,
720 CompileGradientArgType::Const,
721 ],
722 CompileMode::ReverseSens,
723 "calc_stop_srgrad",
724 )?;
725
726 module.post_autodiff_optimisation()?;
727 Ok(module)
728 }
729}
730
731impl CodegenModuleEmit for LlvmModule {
732 fn to_object(&self) -> Result<Vec<u8>> {
733 let module = self.codegen().module();
734 let buffer = self
736 .machine
737 .write_to_memory_buffer(module, FileType::Object)
738 .unwrap()
739 .as_slice()
740 .to_vec();
741 Ok(buffer)
742 }
743}
744impl CodegenModuleJit for LlvmModule {
745 fn jit(&mut self) -> Result<HashMap<String, *const u8>> {
746 let ee = self
747 .codegen()
748 .module()
749 .create_jit_execution_engine(OptimizationLevel::Default)
750 .map_err(|e| anyhow!("Failed to create JIT execution engine: {:?}", e))?;
751
752 let module = self.codegen().module();
758 let mut symbols = HashMap::new();
759 for function in module.get_functions() {
760 let name = function.get_name().to_str().unwrap();
761 let address = ee.get_function_address(name);
762 if let Ok(address) = address {
763 symbols.insert(name.to_string(), address as *const u8);
764 }
765 }
766 Ok(symbols)
767 }
768}
769
/// Handles to the module-level globals created by `Globals::new`.
struct Globals<'ctx> {
    /// `enzyme_const_indices`: constant integer array holding the layout's indices;
    /// `None` when the layout has no indices.
    indices: Option<GlobalValue<'ctx>>,
    /// `enzyme_const_constants`: real array initialised from the layout's constants;
    /// `None` when the layout has no constants.
    constants: Option<GlobalValue<'ctx>>,
    /// `enzyme_const_thread_counter`: zero-initialised integer, created only for
    /// threaded builds.
    thread_counter: Option<GlobalValue<'ctx>>,
    /// `enzyme_const_model_index`: mutable, zero-initialised integer; always present.
    model_index: GlobalValue<'ctx>,
}
776
777impl<'ctx> Globals<'ctx> {
778 fn new(
779 layout: &DataLayout,
780 module: &Module<'ctx>,
781 int_type: IntType<'ctx>,
782 real_type: FloatType<'ctx>,
783 threaded: bool,
784 constants_use_jit: bool,
785 ) -> Self {
786 let thread_counter = if threaded {
787 let tc = module.add_global(
788 int_type,
789 Some(AddressSpace::default()),
790 "enzyme_const_thread_counter",
791 );
792 let tc_value = int_type.const_zero();
799 tc.set_visibility(GlobalVisibility::Hidden);
800 tc.set_initializer(&tc_value.as_basic_value_enum());
801 Some(tc)
802 } else {
803 None
804 };
805 let constants = if layout.constants().is_empty() {
806 None
807 } else {
808 let constants_array_type =
809 real_type.array_type(u32::try_from(layout.constants().len()).unwrap());
810 let constants = module.add_global(
811 constants_array_type,
812 Some(AddressSpace::default()),
813 "enzyme_const_constants",
814 );
815 constants.set_visibility(GlobalVisibility::Hidden);
816 constants.set_constant(!constants_use_jit);
817 let constants_array_values = layout
818 .constants()
819 .iter()
820 .map(|&value| real_type.const_float(value))
821 .collect::<Vec<_>>();
822 let constants_value = real_type.const_array(constants_array_values.as_slice());
823 constants.set_initializer(&constants_value);
824 Some(constants)
825 };
826 let indices = if layout.indices().is_empty() {
827 None
828 } else {
829 let indices_array_type =
830 int_type.array_type(u32::try_from(layout.indices().len()).unwrap());
831 let indices_array_values = layout
832 .indices()
833 .iter()
834 .map(|&i| int_type.const_int(i as u64, true))
835 .collect::<Vec<IntValue>>();
836 let indices_value = int_type.const_array(indices_array_values.as_slice());
837 let indices = module.add_global(
838 indices_array_type,
839 Some(AddressSpace::default()),
840 "enzyme_const_indices",
841 );
842 indices.set_constant(true);
843 indices.set_visibility(GlobalVisibility::Hidden);
844 indices.set_initializer(&indices_value);
845 Some(indices)
846 };
847 let model_index = module.add_global(
848 int_type,
849 Some(AddressSpace::default()),
850 "enzyme_const_model_index",
851 );
852 model_index.set_visibility(GlobalVisibility::Hidden);
853 model_index.set_constant(false);
854 model_index.set_initializer(&int_type.const_zero());
855 Self {
856 indices,
857 thread_counter,
858 constants,
859 model_index,
860 }
861 }
862}
863
/// How one argument of a function is treated when `compile_gradient` differentiates it.
///
/// NOTE(review): the variants presumably map onto Enzyme's `CDIFFE_TYPE_DFT_CONSTANT`,
/// `CDIFFE_TYPE_DFT_DUP_ARG` and `CDIFFE_TYPE_DFT_DUP_NONEED` (all imported by this
/// file) — confirm in `compile_gradient`.
pub enum CompileGradientArgType {
    /// The argument is not differentiated.
    Const,
    /// The argument is paired with a shadow (derivative) argument.
    Dup,
    /// Like `Dup`; presumably the primal value is not needed afterwards — confirm.
    DupNoNeed,
}
869
/// Differentiation mode requested from `compile_gradient`.
///
/// In `from_discrete_model` the variants produce functions with the suffixes `_grad`
/// (`Forward`), `_sgrad` (`ForwardSens`), `_rgrad` (`Reverse`) and `_srgrad`
/// (`ReverseSens`).
pub enum CompileMode {
    /// Forward mode.
    Forward,
    /// Forward mode, applied to the `*_full` variants of the model functions (plus
    /// `set_u0`).
    ForwardSens,
    /// Reverse mode.
    Reverse,
    /// Reverse mode, applied to the `*_full` variants of the model functions.
    ReverseSens,
}
876
/// LLVM IR generator for a single model.
///
/// Owns the module and instruction builder created from `context`, together with the
/// state that is threaded through the individual `compile_*` methods.
pub struct CodeGen<'ctx> {
    /// LLVM context that all types and values in this struct belong to.
    context: &'ctx inkwell::context::Context,
    /// Module that receives every generated function and global.
    module: Module<'ctx>,
    /// Instruction builder; positioned by `start_function`.
    builder: Builder<'ctx>,
    /// Debug info builder; `Some` only when debug output was requested.
    dibuilder: Option<DebugInfoBuilder<'ctx>>,
    /// Debug compile unit; `Some` exactly when `dibuilder` is.
    compile_unit: Option<DICompileUnit<'ctx>>,
    /// Name-to-pointer map, empty on construction.
    /// NOTE(review): presumably the variables of the function being generated — confirm.
    variables: HashMap<String, PointerValue<'ctx>>,
    /// Functions registered through `add_function`, keyed by name.
    functions: HashMap<String, FunctionValue<'ctx>>,
    /// Function currently being generated; set by `start_function`.
    fn_value_opt: Option<FunctionValue<'ctx>>,
    /// NOTE(review): presumably a pointer to the tensor currently being written — confirm.
    tensor_ptr_opt: Option<PointerValue<'ctx>>,
    /// Real number precision selected by the caller.
    diffsl_real_type: RealType,
    /// LLVM float type used for real values.
    real_type: FloatType<'ctx>,
    /// Pointer type obtained from `Self::pointer_type` for `real_type`.
    real_ptr_type: PointerType<'ctx>,
    /// LLVM integer type used for indices and counters.
    int_type: IntType<'ctx>,
    /// Pointer type obtained from `Self::pointer_type` for `int_type`.
    int_ptr_type: PointerType<'ctx>,
    /// Size of a pointer in bits, as supplied by the caller.
    ptr_size_bits: u64,
    /// Size of `int_type` in bits, as supplied by the caller.
    int_size_bits: u64,
    /// Size of `real_type` in bits, as supplied by the caller.
    real_size_bits: u64,
    /// Data layout built from the model.
    layout: DataLayout,
    /// Module-level globals shared by the generated functions.
    globals: Globals<'ctx>,
    /// Whether threaded code (barriers and a thread counter) is generated.
    threaded: bool,
    /// Initialised to `None` in `CodeGen::new`.
    /// NOTE(review): presumably reserved for keeping an execution engine alive — confirm.
    _ee: Option<ExecutionEngine<'ctx>>,
}
900
/// C-ABI callback that ignores all of its arguments and always returns `1`.
///
/// NOTE(review): the signature suggests this is the forward-mode half of a custom call
/// handler registered through `EnzymeRegisterCallHandler`; the meaning Enzyme attaches
/// to the return value `1` is not visible here — confirm against the Enzyme C API.
///
/// # Safety
///
/// None of the pointer arguments is dereferenced.
unsafe extern "C" fn fwd_handler(
    _builder: LLVMBuilderRef,
    _call_instruction: LLVMValueRef,
    _gutils: *mut GradientUtils,
    _dcall: *mut LLVMValueRef,
    _normal_return: *mut LLVMValueRef,
    _shadow_return: *mut LLVMValueRef,
) -> u8 {
    1
}
911
912unsafe extern "C" fn rev_handler(
913 builder: LLVMBuilderRef,
914 call_instruction: LLVMValueRef,
915 gutils: *mut DiffeGradientUtils,
916 _tape: LLVMValueRef,
917) {
918 let call_block = LLVMGetInstructionParent(call_instruction);
919 let call_function = LLVMGetBasicBlockParent(call_block);
920 let module = LLVMGetGlobalParent(call_function);
921 let name_c_str = CString::new("barrier_grad").unwrap();
922 let barrier_func = LLVMGetNamedFunction(module, name_c_str.as_ptr());
923 let barrier_func_type = LLVMGlobalGetValueType(barrier_func);
924 let barrier_num = LLVMGetArgOperand(call_instruction, 0);
925 let total_barriers = LLVMGetArgOperand(call_instruction, 1);
926 let thread_count = LLVMGetArgOperand(call_instruction, 2);
927 let barrier_num = EnzymeGradientUtilsNewFromOriginal(gutils as *mut GradientUtils, barrier_num);
928 let total_barriers =
929 EnzymeGradientUtilsNewFromOriginal(gutils as *mut GradientUtils, total_barriers);
930 let thread_count =
931 EnzymeGradientUtilsNewFromOriginal(gutils as *mut GradientUtils, thread_count);
932 let mut args = [barrier_num, total_barriers, thread_count];
933 let name_c_str = CString::new("").unwrap();
934 LLVMBuildCall2(
935 builder,
936 barrier_func_type,
937 barrier_func,
938 args.as_mut_ptr(),
939 args.len() as u32,
940 name_c_str.as_ptr(),
941 );
942}
943
/// A value that `compile_print_value` can print through `printf`.
// Only used by debugging helpers, hence the `dead_code` allowance.
#[allow(dead_code)]
enum PrintValue<'ctx> {
    /// A real value, printed with the `%f` conversion.
    Real(FloatValue<'ctx>),
    /// An integer value, printed with the `%d` conversion.
    Int(IntValue<'ctx>),
}
949
950impl<'ctx> CodeGen<'ctx> {
951 #[allow(clippy::too_many_arguments)]
952 pub fn new(
953 model: &DiscreteModel,
954 context: &'ctx inkwell::context::Context,
955 diffsl_real_type: RealType,
956 real_type: FloatType<'ctx>,
957 int_type: IntType<'ctx>,
958 threaded: bool,
959 ptr_size_bits: u64,
960 real_size_bits: u64,
961 int_size_bits: u64,
962 debug: bool,
963 constants_use_jit: bool,
964 ) -> Result<Self> {
965 let builder = context.create_builder();
966 let layout = DataLayout::new(model);
967 let module = context.create_module(model.name());
968 let (dibuilder, compile_unit) = if debug {
969 let (dib, dic) = module.create_debug_info_builder(
970 true,
971 inkwell::debug_info::DWARFSourceLanguage::C,
972 model.name(),
973 ".",
974 "diffsl compiler",
975 true,
976 "",
977 0,
978 "",
979 inkwell::debug_info::DWARFEmissionKind::Full,
980 0,
981 false,
982 false,
983 "",
984 "",
985 );
986 (Some(dib), Some(dic))
987 } else {
988 (None, None)
989 };
990 let globals = Globals::new(
991 &layout,
992 &module,
993 int_type,
994 real_type,
995 threaded,
996 constants_use_jit,
997 );
998 let real_ptr_type = Self::pointer_type(context, real_type.into());
999 let int_ptr_type = Self::pointer_type(context, int_type.into());
1000 let mut ret = Self {
1001 context,
1002 module,
1003 builder,
1004 dibuilder,
1005 compile_unit,
1006 real_type,
1007 real_ptr_type,
1008 variables: HashMap::new(),
1009 functions: HashMap::new(),
1010 fn_value_opt: None,
1011 tensor_ptr_opt: None,
1012 layout,
1013 diffsl_real_type,
1014 int_type,
1015 int_ptr_type,
1016 globals,
1017 threaded,
1018 ptr_size_bits,
1019 real_size_bits,
1020 int_size_bits,
1021 _ee: None,
1022 };
1023 if threaded {
1024 ret.compile_barrier_init()?;
1025 ret.compile_barrier()?;
1026 ret.compile_barrier_grad()?;
1027 }
1030 Ok(ret)
1031 }
1032
1033 fn start_function(
1034 &mut self,
1035 function: FunctionValue<'ctx>,
1036 _code: Option<&str>,
1037 ) -> BasicBlock<'ctx> {
1038 let basic_block = self.context.append_basic_block(function, "entry");
1039 self.fn_value_opt = Some(function);
1040 self.builder.position_at_end(basic_block);
1041
1042 if let Some(dibuilder) = &self.dibuilder {
1043 let scope = self
1044 .fn_value_opt
1045 .unwrap()
1046 .get_subprogram()
1047 .unwrap()
1048 .as_debug_info_scope();
1049 let loc = dibuilder.create_debug_location(self.context, 0, 0, scope, None);
1050 self.builder.set_current_debug_location(loc);
1051 }
1052 basic_block
1053 }
1054
1055 fn add_function(
1056 &mut self,
1057 name: &str,
1058 arg_names: &[&str],
1059 arg_types: &[BasicMetadataTypeEnum<'ctx>],
1060 linkage: Option<Linkage>,
1061 is_real_return: bool,
1062 ) -> FunctionValue<'ctx> {
1063 let function_type = if is_real_return {
1064 self.real_type.fn_type(arg_types, false)
1065 } else {
1066 self.context.void_type().fn_type(arg_types, false)
1067 };
1068 let function = self.module.add_function(name, function_type, linkage);
1069 if let (Some(dibuilder), Some(compile_unit)) = (&self.dibuilder, &self.compile_unit) {
1070 let ditypes = arg_names
1071 .iter()
1072 .zip(arg_types.iter())
1073 .map(|(&name, &ty)| {
1074 let size_in_bits = if ty.is_float_type() {
1075 self.real_size_bits
1076 } else if ty.is_int_type() {
1077 self.int_size_bits
1078 } else if ty.is_pointer_type() {
1079 self.ptr_size_bits
1080 } else {
1081 unreachable!("Unsupported argument type for debug info")
1082 };
1083 dibuilder
1084 .create_basic_type(
1085 name,
1086 size_in_bits,
1087 0x00,
1088 <DIFlags as DIFlagsConstants>::PUBLIC,
1089 )
1090 .unwrap()
1091 .as_type()
1092 })
1093 .collect::<Vec<_>>();
1094 let subroutine_type = dibuilder.create_subroutine_type(
1095 compile_unit.get_file(),
1096 None,
1097 &ditypes,
1098 <DIFlags as DIFlagsConstants>::PUBLIC,
1099 );
1100 let func_scope = dibuilder.create_function(
1101 compile_unit.as_debug_info_scope(),
1102 name,
1103 None,
1104 compile_unit.get_file(),
1105 0,
1106 subroutine_type,
1107 true,
1108 true,
1109 0,
1110 <DIFlags as DIFlagsConstants>::PUBLIC,
1111 true,
1112 );
1113 function.set_subprogram(func_scope);
1114 }
1115 self.functions.insert(name.to_owned(), function);
1116 function
1117 }
1118
1119 #[allow(dead_code)]
1120 fn compile_print_value(
1121 &mut self,
1122 name: &str,
1123 value: PrintValue<'ctx>,
1124 ) -> Result<CallSiteValue<'_>> {
1125 let printf = match self.module.get_function("printf") {
1127 Some(f) => f,
1128 None => self.add_function(
1130 "printf",
1131 &["format"],
1132 &[self.int_ptr_type.into()],
1133 Some(Linkage::External),
1134 false,
1135 ),
1136 };
1137 let (format_str, format_str_name) = match value {
1138 PrintValue::Real(_) => (format!("{name}: %f\n"), format!("real_format_{name}")),
1139 PrintValue::Int(_) => (format!("{name}: %d\n"), format!("int_format_{name}")),
1140 };
1141 let format_str = CString::new(format_str).unwrap();
1143 let format_str_global = match self.module.get_global(format_str_name.as_str()) {
1145 Some(g) => g,
1146 None => {
1147 let format_str = self.context.const_string(format_str.as_bytes(), true);
1148 let fmt_str =
1149 self.module
1150 .add_global(format_str.get_type(), None, format_str_name.as_str());
1151 fmt_str.set_initializer(&format_str);
1152 fmt_str.set_visibility(GlobalVisibility::Hidden);
1153 fmt_str
1154 }
1155 };
1156 let format_str_ptr = self.builder.build_pointer_cast(
1158 format_str_global.as_pointer_value(),
1159 self.int_ptr_type,
1160 "format_str_ptr",
1161 )?;
1162 let value: BasicMetadataValueEnum = match value {
1163 PrintValue::Real(v) => v.into(),
1164 PrintValue::Int(v) => v.into(),
1165 };
1166 self.builder
1167 .build_call(printf, &[format_str_ptr.into(), value], "printf_call")
1168 .map_err(|e| anyhow!("Error building call to printf: {}", e))
1169 }
1170
1171 fn compile_set_constants(
1172 &mut self,
1173 model: &DiscreteModel,
1174 code: Option<&str>,
1175 constants_use_jit: bool,
1176 ) -> Result<FunctionValue<'ctx>> {
1177 self.clear();
1178 let fn_arg_names = &["thread_id", "thread_dim"];
1179 let function = self.add_function(
1180 "set_constants",
1181 fn_arg_names,
1182 &[self.int_type.into(), self.int_type.into()],
1183 None,
1184 false,
1185 );
1186 let _basic_block = self.start_function(function, code);
1187
1188 for (i, arg) in function.get_param_iter().enumerate() {
1189 let name = fn_arg_names[i];
1190 let alloca = self.function_arg_alloca(name, arg);
1191 self.insert_param(name, alloca);
1192 }
1193
1194 self.insert_indices();
1195 self.insert_constants(model);
1196
1197 if constants_use_jit {
1198 let mut nbarriers = 0;
1199 let total_barriers = (model.constant_defns().len()) as u64;
1200 let total_barriers_val = self.int_type.const_int(total_barriers, false);
1201 #[allow(clippy::explicit_counter_loop)]
1202 for a in model.constant_defns() {
1203 self.jit_compile_tensor(a, Some(*self.get_var(a)), code)?;
1204 let barrier_num = self.int_type.const_int(nbarriers + 1, false);
1205 self.jit_compile_call_barrier(barrier_num, total_barriers_val);
1206 nbarriers += 1;
1207 }
1208 }
1209
1210 self.builder.build_return(None)?;
1211
1212 if function.verify(true) {
1213 Ok(function)
1214 } else {
1215 function.print_to_stderr();
1216 self.module.print_to_stderr();
1217 unsafe {
1218 function.delete();
1219 }
1220 Err(anyhow!("Invalid generated function."))
1221 }
1222 }
1223
1224 fn compile_barrier_init(&mut self) -> Result<FunctionValue<'ctx>> {
1225 self.clear();
1226 let function = self.add_function("barrier_init", &[], &[], None, false);
1227 let _entry_block = self.start_function(function, None);
1228
1229 let thread_counter = self.globals.thread_counter.unwrap().as_pointer_value();
1230 self.builder
1231 .build_store(thread_counter, self.int_type.const_zero())?;
1232
1233 self.builder.build_return(None)?;
1234
1235 if function.verify(true) {
1236 Ok(function)
1237 } else {
1238 function.print_to_stderr();
1239 unsafe {
1240 function.delete();
1241 }
1242 Err(anyhow!("Invalid generated function."))
1243 }
1244 }
1245
1246 fn compile_barrier(&mut self) -> Result<FunctionValue<'ctx>> {
1247 self.clear();
1248 let function = self.add_function(
1249 "barrier",
1250 &["barrier_num", "total_barriers", "thread_count"],
1251 &[
1252 self.int_type.into(),
1253 self.int_type.into(),
1254 self.int_type.into(),
1255 ],
1256 None,
1257 false,
1258 );
1259 let nolinline_kind_id = Attribute::get_named_enum_kind_id("noinline");
1260 let noinline = self.context.create_enum_attribute(nolinline_kind_id, 0);
1261 function.add_attribute(AttributeLoc::Function, noinline);
1262
1263 let _entry_block = self.start_function(function, None);
1264 let increment_block = self.context.append_basic_block(function, "increment");
1265 let wait_loop_block = self.context.append_basic_block(function, "wait_loop");
1266 let barrier_done_block = self.context.append_basic_block(function, "barrier_done");
1267
1268 let thread_counter = self.globals.thread_counter.unwrap().as_pointer_value();
1269 let barrier_num = function.get_nth_param(0).unwrap().into_int_value();
1270 let total_barriers = function.get_nth_param(1).unwrap().into_int_value();
1271 let thread_count = function.get_nth_param(2).unwrap().into_int_value();
1272
1273 let nbarrier_equals_total_barriers = self
1274 .builder
1275 .build_int_compare(
1276 IntPredicate::EQ,
1277 barrier_num,
1278 total_barriers,
1279 "nbarrier_equals_total_barriers",
1280 )
1281 .unwrap();
1282 self.builder.build_conditional_branch(
1284 nbarrier_equals_total_barriers,
1285 barrier_done_block,
1286 increment_block,
1287 )?;
1288 self.builder.position_at_end(increment_block);
1289
1290 let barrier_num_times_thread_count = self
1291 .builder
1292 .build_int_mul(barrier_num, thread_count, "barrier_num_times_thread_count")
1293 .unwrap();
1294
1295 let i32_type = self.context.i32_type();
1297 let one = i32_type.const_int(1, false);
1298 self.builder.build_atomicrmw(
1299 AtomicRMWBinOp::Add,
1300 thread_counter,
1301 one,
1302 AtomicOrdering::Monotonic,
1303 )?;
1304
1305 self.builder.build_unconditional_branch(wait_loop_block)?;
1307 self.builder.position_at_end(wait_loop_block);
1308
1309 let current_value = self
1310 .builder
1311 .build_load(i32_type, thread_counter, "current_value")?
1312 .into_int_value();
1313
1314 current_value
1315 .as_instruction_value()
1316 .unwrap()
1317 .set_atomic_ordering(AtomicOrdering::Monotonic)
1318 .map_err(|e| anyhow!("Error setting atomic ordering: {:?}", e))?;
1319
1320 let all_threads_done = self.builder.build_int_compare(
1321 IntPredicate::UGE,
1322 current_value,
1323 barrier_num_times_thread_count,
1324 "all_threads_done",
1325 )?;
1326
1327 self.builder.build_conditional_branch(
1328 all_threads_done,
1329 barrier_done_block,
1330 wait_loop_block,
1331 )?;
1332 self.builder.position_at_end(barrier_done_block);
1333
1334 self.builder.build_return(None)?;
1335
1336 if function.verify(true) {
1337 Ok(function)
1338 } else {
1339 function.print_to_stderr();
1340 unsafe {
1341 function.delete();
1342 }
1343 Err(anyhow!("Invalid generated function."))
1344 }
1345 }
1346
1347 fn compile_barrier_grad(&mut self) -> Result<FunctionValue<'ctx>> {
1348 self.clear();
1349 let function = self.add_function(
1350 "barrier_grad",
1351 &["barrier_num", "total_barriers", "thread_count"],
1352 &[
1353 self.int_type.into(),
1354 self.int_type.into(),
1355 self.int_type.into(),
1356 ],
1357 None,
1358 false,
1359 );
1360
1361 let _entry_block = self.start_function(function, None);
1362 let wait_loop_block = self.context.append_basic_block(function, "wait_loop");
1363 let barrier_done_block = self.context.append_basic_block(function, "barrier_done");
1364
1365 let thread_counter = self.globals.thread_counter.unwrap().as_pointer_value();
1366 let barrier_num = function.get_nth_param(0).unwrap().into_int_value();
1367 let total_barriers = function.get_nth_param(1).unwrap().into_int_value();
1368 let thread_count = function.get_nth_param(2).unwrap().into_int_value();
1369
1370 let twice_total_barriers = self
1371 .builder
1372 .build_int_mul(
1373 total_barriers,
1374 self.int_type.const_int(2, false),
1375 "twice_total_barriers",
1376 )
1377 .unwrap();
1378 let twice_total_barriers_minus_barrier_num = self
1379 .builder
1380 .build_int_sub(
1381 twice_total_barriers,
1382 barrier_num,
1383 "twice_total_barriers_minus_barrier_num",
1384 )
1385 .unwrap();
1386 let twice_total_barriers_minus_barrier_num_times_thread_count = self
1387 .builder
1388 .build_int_mul(
1389 twice_total_barriers_minus_barrier_num,
1390 thread_count,
1391 "twice_total_barriers_minus_barrier_num_times_thread_count",
1392 )
1393 .unwrap();
1394
1395 let i32_type = self.context.i32_type();
1397 let one = i32_type.const_int(1, false);
1398 self.builder.build_atomicrmw(
1399 AtomicRMWBinOp::Add,
1400 thread_counter,
1401 one,
1402 AtomicOrdering::Monotonic,
1403 )?;
1404
1405 self.builder.build_unconditional_branch(wait_loop_block)?;
1407 self.builder.position_at_end(wait_loop_block);
1408
1409 let current_value = self
1410 .builder
1411 .build_load(i32_type, thread_counter, "current_value")?
1412 .into_int_value();
1413 current_value
1414 .as_instruction_value()
1415 .unwrap()
1416 .set_atomic_ordering(AtomicOrdering::Monotonic)
1417 .map_err(|e| anyhow!("Error setting atomic ordering: {:?}", e))?;
1418
1419 let all_threads_done = self.builder.build_int_compare(
1420 IntPredicate::UGE,
1421 current_value,
1422 twice_total_barriers_minus_barrier_num_times_thread_count,
1423 "all_threads_done",
1424 )?;
1425
1426 self.builder.build_conditional_branch(
1427 all_threads_done,
1428 barrier_done_block,
1429 wait_loop_block,
1430 )?;
1431 self.builder.position_at_end(barrier_done_block);
1432
1433 self.builder.build_return(None)?;
1434
1435 if function.verify(true) {
1436 Ok(function)
1437 } else {
1438 function.print_to_stderr();
1439 unsafe {
1440 function.delete();
1441 }
1442 Err(anyhow!("Invalid generated function."))
1443 }
1444 }
1445
1446 fn jit_compile_call_barrier(
1447 &mut self,
1448 barrier_num: IntValue<'ctx>,
1449 total_barriers: IntValue<'ctx>,
1450 ) {
1451 if !self.threaded {
1452 return;
1453 }
1454 let thread_dim = self.get_param("thread_dim");
1455 let thread_dim = self
1456 .builder
1457 .build_load(self.int_type, *thread_dim, "thread_dim")
1458 .unwrap()
1459 .into_int_value();
1460 let barrier = self.get_function("barrier").unwrap();
1461 self.builder
1462 .build_call(
1463 barrier,
1464 &[
1465 BasicMetadataValueEnum::IntValue(barrier_num),
1466 BasicMetadataValueEnum::IntValue(total_barriers),
1467 BasicMetadataValueEnum::IntValue(thread_dim),
1468 ],
1469 "barrier",
1470 )
1471 .unwrap();
1472 }
1473
1474 fn jit_threading_limits(
1475 &mut self,
1476 size: IntValue<'ctx>,
1477 ) -> Result<(
1478 IntValue<'ctx>,
1479 IntValue<'ctx>,
1480 BasicBlock<'ctx>,
1481 BasicBlock<'ctx>,
1482 )> {
1483 let one = self.int_type.const_int(1, false);
1484 let thread_id = self.get_param("thread_id");
1485 let thread_id = self
1486 .builder
1487 .build_load(self.int_type, *thread_id, "thread_id")
1488 .unwrap()
1489 .into_int_value();
1490 let thread_dim = self.get_param("thread_dim");
1491 let thread_dim = self
1492 .builder
1493 .build_load(self.int_type, *thread_dim, "thread_dim")
1494 .unwrap()
1495 .into_int_value();
1496
1497 let i_times_size = self
1499 .builder
1500 .build_int_mul(thread_id, size, "i_times_size")?;
1501 let start = self
1502 .builder
1503 .build_int_unsigned_div(i_times_size, thread_dim, "start")?;
1504
1505 let i_plus_one = self.builder.build_int_add(thread_id, one, "i_plus_one")?;
1507 let i_plus_one_times_size =
1508 self.builder
1509 .build_int_mul(i_plus_one, size, "i_plus_one_times_size")?;
1510 let end = self
1511 .builder
1512 .build_int_unsigned_div(i_plus_one_times_size, thread_dim, "end")?;
1513
1514 let test_done = self.builder.get_insert_block().unwrap();
1515 let next_block = self
1516 .context
1517 .append_basic_block(self.fn_value_opt.unwrap(), "threading_block");
1518 self.builder.position_at_end(next_block);
1519
1520 Ok((start, end, test_done, next_block))
1521 }
1522
1523 fn jit_end_threading(
1524 &mut self,
1525 start: IntValue<'ctx>,
1526 end: IntValue<'ctx>,
1527 test_done: BasicBlock<'ctx>,
1528 next: BasicBlock<'ctx>,
1529 ) -> Result<()> {
1530 let exit = self
1531 .context
1532 .append_basic_block(self.fn_value_opt.unwrap(), "exit");
1533 self.builder.build_unconditional_branch(exit)?;
1534 self.builder.position_at_end(test_done);
1535 let done = self
1537 .builder
1538 .build_int_compare(IntPredicate::EQ, start, end, "done")?;
1539 self.builder.build_conditional_branch(done, exit, next)?;
1540 self.builder.position_at_end(exit);
1541 Ok(())
1542 }
1543
1544 pub fn write_bitcode_to_path(&self, path: &std::path::Path) {
1545 self.module.write_bitcode_to_path(path);
1546 }
1547
1548 fn insert_constants(&mut self, model: &DiscreteModel) {
1549 if let Some(constants) = self.globals.constants.as_ref() {
1550 self.insert_param("constants", constants.as_pointer_value());
1551 for tensor in model.constant_defns() {
1552 self.insert_tensor(tensor, true);
1553 }
1554 }
1555 }
1556
1557 fn insert_data(&mut self, model: &DiscreteModel) {
1558 self.insert_model_index();
1559 self.insert_constants(model);
1560
1561 if let Some(input) = model.input() {
1562 self.insert_tensor(input, false);
1563 }
1564 for tensor in model.input_dep_defns() {
1565 self.insert_tensor(tensor, false);
1566 }
1567 for tensor in model.time_dep_defns() {
1568 self.insert_tensor(tensor, false);
1569 }
1570 for tensor in model.state_dep_defns() {
1571 self.insert_tensor(tensor, false);
1572 }
1573 for tensor in model.state_dep_post_f_defns() {
1574 self.insert_tensor(tensor, false);
1575 }
1576 }
1577
1578 fn pointer_type(context: &'ctx Context, _ty: BasicTypeEnum<'ctx>) -> PointerType<'ctx> {
1579 context.ptr_type(AddressSpace::default())
1580 }
1581
1582 fn fn_pointer_type(context: &'ctx Context, _ty: FunctionType<'ctx>) -> PointerType<'ctx> {
1583 context.ptr_type(AddressSpace::default())
1584 }
1585
1586 fn insert_indices(&mut self) {
1587 if let Some(indices) = self.globals.indices.as_ref() {
1588 let i32_type = self.context.i32_type();
1589 let zero = i32_type.const_int(0, false);
1590 let ptr = unsafe {
1591 indices
1592 .as_pointer_value()
1593 .const_in_bounds_gep(i32_type, &[zero])
1594 };
1595 self.variables.insert("indices".to_owned(), ptr);
1596 }
1597 }
1598
1599 fn insert_model_index(&mut self) {
1600 self.insert_param("model_index", self.globals.model_index.as_pointer_value());
1601 }
1602
1603 fn insert_param(&mut self, name: &str, value: PointerValue<'ctx>) {
1604 self.variables.insert(name.to_owned(), value);
1605 }
1606
1607 fn build_gep<T: BasicType<'ctx>>(
1608 &self,
1609 ty: T,
1610 ptr: PointerValue<'ctx>,
1611 ordered_indexes: &[IntValue<'ctx>],
1612 name: &str,
1613 ) -> Result<PointerValue<'ctx>> {
1614 unsafe {
1615 self.builder
1616 .build_gep(ty, ptr, ordered_indexes, name)
1617 .map_err(|e| e.into())
1618 }
1619 }
1620
1621 fn build_load<T: BasicType<'ctx>>(
1622 &self,
1623 ty: T,
1624 ptr: PointerValue<'ctx>,
1625 name: &str,
1626 ) -> Result<BasicValueEnum<'ctx>> {
1627 self.builder.build_load(ty, ptr, name).map_err(|e| e.into())
1628 }
1629
1630 fn get_ptr_to_index<T: BasicType<'ctx>>(
1631 builder: &Builder<'ctx>,
1632 ty: T,
1633 ptr: &PointerValue<'ctx>,
1634 index: IntValue<'ctx>,
1635 name: &str,
1636 ) -> PointerValue<'ctx> {
1637 unsafe {
1638 builder
1639 .build_in_bounds_gep(ty, *ptr, &[index], name)
1640 .unwrap()
1641 }
1642 }
1643
1644 fn insert_state(&mut self, u: &Tensor) {
1645 let mut data_index = 0;
1646 for blk in u.elmts() {
1647 if let Some(name) = blk.name() {
1648 let ptr = self.variables.get("u").unwrap();
1649 let i = self
1650 .context
1651 .i32_type()
1652 .const_int(data_index.try_into().unwrap(), false);
1653 let alloca = Self::get_ptr_to_index(
1654 &self.create_entry_block_builder(),
1655 self.real_type,
1656 ptr,
1657 i,
1658 blk.name().unwrap(),
1659 );
1660 self.variables.insert(name.to_owned(), alloca);
1661 }
1662 data_index += blk.nnz();
1663 }
1664 }
1665 fn insert_dot_state(&mut self, dudt: &Tensor) {
1666 let mut data_index = 0;
1667 for blk in dudt.elmts() {
1668 if let Some(name) = blk.name() {
1669 let ptr = self.variables.get("dudt").unwrap();
1670 let i = self
1671 .context
1672 .i32_type()
1673 .const_int(data_index.try_into().unwrap(), false);
1674 let alloca = Self::get_ptr_to_index(
1675 &self.create_entry_block_builder(),
1676 self.real_type,
1677 ptr,
1678 i,
1679 blk.name().unwrap(),
1680 );
1681 self.variables.insert(name.to_owned(), alloca);
1682 }
1683 data_index += blk.nnz();
1684 }
1685 }
1686 fn insert_tensor(&mut self, tensor: &Tensor, is_constant: bool) {
1687 let var_name = if is_constant { "constants" } else { "data" };
1688 let ptr = *self.variables.get(var_name).unwrap();
1689 let mut data_index = self.layout.get_data_index(tensor.name()).unwrap();
1690 let i = self
1691 .context
1692 .i32_type()
1693 .const_int(data_index.try_into().unwrap(), false);
1694 let alloca = Self::get_ptr_to_index(
1695 &self.create_entry_block_builder(),
1696 self.real_type,
1697 &ptr,
1698 i,
1699 tensor.name(),
1700 );
1701 self.variables.insert(tensor.name().to_owned(), alloca);
1702
1703 for blk in tensor.elmts() {
1705 if let Some(name) = blk.name() {
1706 let i = self
1707 .context
1708 .i32_type()
1709 .const_int(data_index.try_into().unwrap(), false);
1710 let alloca = Self::get_ptr_to_index(
1711 &self.create_entry_block_builder(),
1712 self.real_type,
1713 &ptr,
1714 i,
1715 name,
1716 );
1717 self.variables.insert(name.to_owned(), alloca);
1718 }
1719 data_index += blk.nnz();
1721 }
1722 }
1723 fn get_param(&self, name: &str) -> &PointerValue<'ctx> {
1724 self.variables.get(name).unwrap()
1725 }
1726
1727 fn get_var(&self, tensor: &Tensor) -> &PointerValue<'ctx> {
1728 self.variables.get(tensor.name()).unwrap()
1729 }
1730
1731 fn get_function(&mut self, name: &str) -> Option<FunctionValue<'ctx>> {
1732 if let Some(&func) = self.functions.get(name) {
1734 return Some(func);
1735 }
1736 let current_block = self.builder.get_insert_block().unwrap();
1737 let current_loc = self.builder.get_current_debug_location().unwrap();
1738 let current_function = self.fn_value_opt.unwrap();
1739
1740 let function = match name {
1741 "sin" | "cos" | "tan" | "exp" | "log" | "log10" | "sqrt" | "abs" | "copysign"
1743 | "pow" | "min" | "max" => {
1744 let intrinsic_name = match name {
1745 "min" => "minnum",
1746 "max" => "maxnum",
1747 "abs" => "fabs",
1748 _ => name,
1749 };
1750 let llvm_name =
1751 format!("llvm.{}.{}", intrinsic_name, self.diffsl_real_type.as_str());
1752
1753 let args_types: Vec<BasicTypeEnum> = vec![self.real_type.into()];
1754 self.builder.unset_current_debug_location();
1755
1756 Some(match Intrinsic::find(&llvm_name) {
1758 Some(intrinsic) => intrinsic.get_declaration(&self.module, &args_types)?,
1759 None => {
1760 let args_types_meta: Vec<BasicMetadataTypeEnum> =
1762 vec![self.real_type.into()];
1763 let function_type = self.real_type.fn_type(&args_types_meta, false);
1764 self.module
1765 .add_function(name, function_type, Some(Linkage::External))
1766 }
1767 })
1768 }
1769 "sigmoid" => {
1771 let arg_len = 1;
1772 let ret_type = self.real_type;
1773
1774 let args_types = std::iter::repeat_n(ret_type, arg_len)
1775 .map(|f| f.into())
1776 .collect::<Vec<BasicMetadataTypeEnum>>();
1777
1778 let fn_val = self.add_function(name, &["x"], &args_types, None, true);
1779
1780 for arg in fn_val.get_param_iter() {
1781 arg.into_float_value().set_name("x");
1782 }
1783 let _basic_block = self.start_function(fn_val, None);
1784
1785 let x = fn_val.get_nth_param(0)?.into_float_value();
1786 let one = self.real_type.const_float(1.0);
1787 let negx = self.builder.build_float_neg(x, name).ok()?;
1788 let exp = self.get_function("exp").unwrap();
1789 let exp_negx = self
1790 .builder
1791 .build_call(exp, &[BasicMetadataValueEnum::FloatValue(negx)], name)
1792 .ok()?;
1793 let one_plus_exp_negx = self
1794 .builder
1795 .build_float_add(
1796 exp_negx
1797 .try_as_basic_value()
1798 .unwrap_basic()
1799 .into_float_value(),
1800 one,
1801 name,
1802 )
1803 .ok()?;
1804 let sigmoid = self
1805 .builder
1806 .build_float_div(one, one_plus_exp_negx, name)
1807 .ok()?;
1808 self.builder.build_return(Some(&sigmoid)).ok();
1809 Some(fn_val)
1810 }
1811 "arcsinh" | "arccosh" => {
1812 let arg_len = 1;
1813 let ret_type = self.real_type;
1814
1815 let args_types = std::iter::repeat_n(ret_type, arg_len)
1816 .map(|f| f.into())
1817 .collect::<Vec<BasicMetadataTypeEnum>>();
1818
1819 let fn_val = self.add_function(name, &["x"], &args_types, None, true);
1820
1821 for arg in fn_val.get_param_iter() {
1822 arg.into_float_value().set_name("x");
1823 }
1824
1825 let _basic_block = self.start_function(fn_val, None);
1826 let x = fn_val.get_nth_param(0)?.into_float_value();
1827 let one = match name {
1828 "arccosh" => self.real_type.const_float(-1.0),
1829 "arcsinh" => self.real_type.const_float(1.0),
1830 _ => panic!("unknown function"),
1831 };
1832 let x_squared = self.builder.build_float_mul(x, x, name).ok()?;
1833 let one_plus_x_squared = self.builder.build_float_add(x_squared, one, name).ok()?;
1834 let sqrt = self.get_function("sqrt").unwrap();
1835 let sqrt_one_plus_x_squared = self
1836 .builder
1837 .build_call(
1838 sqrt,
1839 &[BasicMetadataValueEnum::FloatValue(one_plus_x_squared)],
1840 name,
1841 )
1842 .unwrap()
1843 .try_as_basic_value()
1844 .unwrap_basic()
1845 .into_float_value();
1846 let x_plus_sqrt_one_plus_x_squared = self
1847 .builder
1848 .build_float_add(x, sqrt_one_plus_x_squared, name)
1849 .ok()?;
1850 let ln = self.get_function("log").unwrap();
1851 let result = self
1852 .builder
1853 .build_call(
1854 ln,
1855 &[BasicMetadataValueEnum::FloatValue(
1856 x_plus_sqrt_one_plus_x_squared,
1857 )],
1858 name,
1859 )
1860 .unwrap()
1861 .try_as_basic_value()
1862 .unwrap_basic()
1863 .into_float_value();
1864 self.builder.build_return(Some(&result)).ok();
1865 Some(fn_val)
1866 }
1867 "heaviside" => {
1868 let arg_len = 1;
1869 let ret_type = self.real_type;
1870
1871 let args_types = std::iter::repeat_n(ret_type, arg_len)
1872 .map(|f| f.into())
1873 .collect::<Vec<BasicMetadataTypeEnum>>();
1874
1875 let fn_val = self.add_function(name, &["x"], &args_types, None, true);
1876
1877 for arg in fn_val.get_param_iter() {
1878 arg.into_float_value().set_name("x");
1879 }
1880
1881 let _basic_block = self.start_function(fn_val, None);
1882 let x = fn_val.get_nth_param(0)?.into_float_value();
1883 let zero = self.real_type.const_float(0.0);
1884 let one = self.real_type.const_float(1.0);
1885 let result = self
1886 .builder
1887 .build_select(
1888 self.builder
1889 .build_float_compare(FloatPredicate::OGE, x, zero, "x >= 0")
1890 .unwrap(),
1891 one,
1892 zero,
1893 name,
1894 )
1895 .ok()?;
1896 self.builder.build_return(Some(&result)).ok();
1897 Some(fn_val)
1898 }
1899 "tanh" | "sinh" | "cosh" => {
1900 let arg_len = 1;
1901 let ret_type = self.real_type;
1902
1903 let args_types = std::iter::repeat_n(ret_type, arg_len)
1904 .map(|f| f.into())
1905 .collect::<Vec<BasicMetadataTypeEnum>>();
1906
1907 let fn_val = self.add_function(name, &["x"], &args_types, None, true);
1908
1909 for arg in fn_val.get_param_iter() {
1910 arg.into_float_value().set_name("x");
1911 }
1912
1913 let _basic_block = self.start_function(fn_val, None);
1914 let x = fn_val.get_nth_param(0)?.into_float_value();
1915 let negx = self.builder.build_float_neg(x, name).ok()?;
1916 let exp = self.get_function("exp").unwrap();
1917 let exp_negx = self
1918 .builder
1919 .build_call(exp, &[BasicMetadataValueEnum::FloatValue(negx)], name)
1920 .ok()?;
1921 let expx = self
1922 .builder
1923 .build_call(exp, &[BasicMetadataValueEnum::FloatValue(x)], name)
1924 .ok()?;
1925 let expx_minus_exp_negx = self
1926 .builder
1927 .build_float_sub(
1928 expx.try_as_basic_value().unwrap_basic().into_float_value(),
1929 exp_negx
1930 .try_as_basic_value()
1931 .unwrap_basic()
1932 .into_float_value(),
1933 name,
1934 )
1935 .ok()?;
1936 let expx_plus_exp_negx = self
1937 .builder
1938 .build_float_add(
1939 expx.try_as_basic_value().unwrap_basic().into_float_value(),
1940 exp_negx
1941 .try_as_basic_value()
1942 .unwrap_basic()
1943 .into_float_value(),
1944 name,
1945 )
1946 .ok()?;
1947 let result = match name {
1948 "tanh" => self
1949 .builder
1950 .build_float_div(expx_minus_exp_negx, expx_plus_exp_negx, name)
1951 .ok()?,
1952 "sinh" => self
1953 .builder
1954 .build_float_div(expx_minus_exp_negx, self.real_type.const_float(2.0), name)
1955 .ok()?,
1956 "cosh" => self
1957 .builder
1958 .build_float_div(expx_plus_exp_negx, self.real_type.const_float(2.0), name)
1959 .ok()?,
1960 _ => panic!("unknown function"),
1961 };
1962 self.builder.build_return(Some(&result)).ok();
1963 Some(fn_val)
1964 }
1965 _ => None,
1966 }?;
1967
1968 if !function.verify(true) {
1969 function.print_to_stderr();
1970 panic!("Invalid generated function for {}", name);
1971 }
1972
1973 self.builder.position_at_end(current_block);
1974 if self.dibuilder.is_some() {
1975 self.builder.set_current_debug_location(current_loc);
1976 }
1977 self.fn_value_opt = Some(current_function);
1978
1979 self.functions.insert(name.to_owned(), function);
1980 Some(function)
1981 }
1982
1983 #[inline]
1985 fn fn_value(&self) -> FunctionValue<'ctx> {
1986 self.fn_value_opt.unwrap()
1987 }
1988
1989 #[inline]
1990 fn tensor_ptr(&self) -> PointerValue<'ctx> {
1991 self.tensor_ptr_opt.unwrap()
1992 }
1993
1994 fn create_entry_block_builder(&self) -> Builder<'ctx> {
1996 let builder = self.context.create_builder();
1997 let entry = self.fn_value().get_first_basic_block().unwrap();
1998 match entry.get_first_instruction() {
1999 Some(first_instr) => builder.position_before(&first_instr),
2000 None => builder.position_at_end(entry),
2001 }
2002 builder
2003 }
2004
2005 fn jit_compile_scalar(
2006 &mut self,
2007 a: &Tensor,
2008 res_ptr_opt: Option<PointerValue<'ctx>>,
2009 ) -> Result<PointerValue<'ctx>> {
2010 let res_type = self.real_type;
2011 let res_ptr = match res_ptr_opt {
2012 Some(ptr) => ptr,
2013 None => self
2014 .create_entry_block_builder()
2015 .build_alloca(res_type, a.name())?,
2016 };
2017 let name = a.name();
2018 let elmt = a.elmts().first().unwrap();
2019
2020 let curr_block = self.builder.get_insert_block().unwrap();
2022 let mut next_block_opt = None;
2023 if self.threaded {
2024 let next_block = self.context.append_basic_block(self.fn_value(), "next");
2025 self.builder.position_at_end(next_block);
2026 next_block_opt = Some(next_block);
2027 }
2028
2029 let zero = self.int_type.const_zero();
2030 let float_value = self.jit_compile_expr(name, elmt.expr(), &[], elmt, zero)?;
2031 self.builder.build_store(res_ptr, float_value)?;
2032
2033 if self.threaded {
2035 let exit_block = self.context.append_basic_block(self.fn_value(), "exit");
2036 self.builder.build_unconditional_branch(exit_block)?;
2037 self.builder.position_at_end(curr_block);
2038
2039 let thread_id = self.get_param("thread_id");
2040 let thread_id = self
2041 .builder
2042 .build_load(self.int_type, *thread_id, "thread_id")
2043 .unwrap()
2044 .into_int_value();
2045 let is_first_thread = self.builder.build_int_compare(
2046 IntPredicate::EQ,
2047 thread_id,
2048 self.int_type.const_zero(),
2049 "is_first_thread",
2050 )?;
2051 self.builder.build_conditional_branch(
2052 is_first_thread,
2053 next_block_opt.unwrap(),
2054 exit_block,
2055 )?;
2056
2057 self.builder.position_at_end(exit_block);
2058 }
2059
2060 Ok(res_ptr)
2061 }
2062
2063 fn jit_compile_tensor(
2064 &mut self,
2065 a: &Tensor,
2066 res_ptr_opt: Option<PointerValue<'ctx>>,
2067 code: Option<&str>,
2068 ) -> Result<PointerValue<'ctx>> {
2069 if a.rank() == 0
2071 && a.elmts()
2072 .first()
2073 .is_none_or(|b| b.expr_layout().rank() == 0)
2074 {
2075 return self.jit_compile_scalar(a, res_ptr_opt);
2076 }
2077
2078 let res_type = self.real_type;
2079 let res_ptr = match res_ptr_opt {
2080 Some(ptr) => ptr,
2081 None => self
2082 .create_entry_block_builder()
2083 .build_alloca(res_type, a.name())?,
2084 };
2085
2086 self.tensor_ptr_opt = Some(res_ptr);
2088
2089 for (i, blk) in a.elmts().iter().enumerate() {
2090 if blk.has_values() {
2091 continue;
2092 }
2093 let default = format!("{}-{}", a.name(), i);
2094 let name = blk.name().unwrap_or(default.as_str());
2095 self.jit_compile_block(name, a, blk, code)?;
2096 }
2097 Ok(res_ptr)
2098 }
2099
2100 fn jit_compile_block(
2101 &mut self,
2102 name: &str,
2103 tensor: &Tensor,
2104 elmt: &TensorBlock,
2105 code: Option<&str>,
2106 ) -> Result<()> {
2107 let translation = Translation::new(
2108 elmt.expr_layout(),
2109 elmt.layout(),
2110 elmt.start(),
2111 tensor.layout_ptr(),
2112 );
2113
2114 if let Some(dibuilder) = &self.dibuilder {
2115 let (line_no, column_no) = match (elmt.expr().span, code) {
2116 (Some(s), Some(c)) => {
2117 let s = Span::new(c, s.pos_start, s.pos_end).unwrap();
2118 s.start_pos().line_col()
2119 }
2120 _ => (0, 0),
2121 };
2122 let scope = self
2123 .fn_value_opt
2124 .unwrap()
2125 .get_subprogram()
2126 .unwrap()
2127 .as_debug_info_scope();
2128 let loc = dibuilder.create_debug_location(
2129 self.context,
2130 line_no.try_into().unwrap(),
2131 column_no.try_into().unwrap(),
2132 scope,
2133 None,
2134 );
2135 self.builder.set_current_debug_location(loc);
2136 }
2137
2138 if elmt.expr_layout().is_dense()
2141 || matches!(translation.source, TranslationFrom::DenseContraction { .. })
2142 {
2143 self.jit_compile_dense_block(name, elmt, &translation)
2144 } else if elmt.expr_layout().is_diagonal() {
2145 self.jit_compile_diagonal_block(name, elmt, &translation)
2146 } else if elmt.expr_layout().is_sparse() {
2147 match translation.source {
2148 TranslationFrom::SparseContraction { .. } => {
2149 self.jit_compile_sparse_contraction_block(name, elmt, &translation)
2150 }
2151 _ => self.jit_compile_sparse_block(name, elmt, &translation),
2152 }
2153 } else {
2154 Err(anyhow!(
2155 "unsupported block layout: {:?}",
2156 elmt.expr_layout()
2157 ))
2158 }
2159 }
2160
2161 fn jit_compile_dense_block(
2163 &mut self,
2164 name: &str,
2165 elmt: &TensorBlock,
2166 translation: &Translation,
2167 ) -> Result<()> {
2168 let int_type = self.int_type;
2169
2170 let mut preblock = self.builder.get_insert_block().unwrap();
2171 let expr_rank = elmt.expr_layout().rank();
2172 let expr_shape = elmt
2173 .expr_layout()
2174 .shape()
2175 .mapv(|n| int_type.const_int(n.try_into().unwrap(), false));
2176 let one = int_type.const_int(1, false);
2177
2178 let mut expr_strides = vec![1; expr_rank];
2179 if expr_rank > 0 {
2180 for i in (0..expr_rank - 1).rev() {
2181 expr_strides[i] = expr_strides[i + 1] * elmt.expr_layout().shape()[i + 1];
2182 }
2183 }
2184 let expr_strides = expr_strides
2185 .iter()
2186 .map(|&s| int_type.const_int(s.try_into().unwrap(), false))
2187 .collect::<Vec<IntValue>>();
2188
2189 let mut indices = Vec::new();
2191 let mut blocks = Vec::new();
2192
2193 let (contract_sum, contract_by, contract_strides) =
2195 if let TranslationFrom::DenseContraction {
2196 contract_by,
2197 contract_len: _,
2198 } = translation.source
2199 {
2200 let contract_rank = expr_rank - contract_by;
2201 let mut contract_strides = vec![1; contract_rank];
2202 for i in (0..contract_rank.saturating_sub(1)).rev() {
2203 contract_strides[i] =
2204 contract_strides[i + 1] * elmt.expr_layout().shape()[i + 1];
2205 }
2206 let contract_strides = contract_strides
2207 .iter()
2208 .map(|&s| int_type.const_int(s.try_into().unwrap(), false))
2209 .collect::<Vec<IntValue>>();
2210 (
2211 Some(self.builder.build_alloca(self.real_type, "contract_sum")?),
2212 contract_by,
2213 Some(contract_strides),
2214 )
2215 } else {
2216 (None, 0, None)
2217 };
2218
2219 let (thread_start, thread_end, test_done, next) = if self.threaded {
2221 let (start, end, test_done, next) =
2222 self.jit_threading_limits(*expr_shape.get(0).unwrap_or(&one))?;
2223 preblock = next;
2224 (Some(start), Some(end), Some(test_done), Some(next))
2225 } else {
2226 (None, None, None, None)
2227 };
2228
2229 if contract_by == expr_rank {
2231 if let Some(contract_sum) = contract_sum {
2232 self.builder
2233 .build_store(contract_sum, self.real_type.const_zero())?;
2234 }
2235 }
2236
2237 for i in 0..expr_rank {
2238 let block = self.context.append_basic_block(self.fn_value(), name);
2239 self.builder.build_unconditional_branch(block)?;
2240 self.builder.position_at_end(block);
2241
2242 let start_index = if i == 0 && self.threaded {
2243 thread_start.unwrap()
2244 } else {
2245 self.int_type.const_zero()
2246 };
2247
2248 let curr_index = self.builder.build_phi(int_type, format!["i{i}"].as_str())?;
2249 curr_index.add_incoming(&[(&start_index, preblock)]);
2250
2251 let init_point = (expr_rank - contract_by).saturating_sub(1);
2252 if i == init_point && contract_by != expr_rank {
2253 if let Some(contract_sum) = contract_sum {
2254 self.builder
2255 .build_store(contract_sum, self.real_type.const_zero())?;
2256 }
2257 }
2258
2259 indices.push(curr_index);
2260 blocks.push(block);
2261 preblock = block;
2262 }
2263
2264 let indices_int: Vec<IntValue> = indices
2265 .iter()
2266 .map(|i| i.as_basic_value().into_int_value())
2267 .collect();
2268
2269 let mut expr_index = *indices_int.last().unwrap_or(&int_type.const_zero());
2271 let mut stride = 1u64;
2272 if !indices.is_empty() {
2273 for i in (0..indices.len() - 1).rev() {
2274 let iname_i = indices_int[i];
2275 let shapei: u64 = elmt.expr_layout().shape()[i + 1].try_into().unwrap();
2276 stride *= shapei;
2277 let stride_intval = self.context.i32_type().const_int(stride, false);
2278 let stride_mul_i = self.builder.build_int_mul(stride_intval, iname_i, name)?;
2279 expr_index = self.builder.build_int_add(expr_index, stride_mul_i, name)?;
2280 }
2281 }
2282
2283 let float_value =
2284 self.jit_compile_expr(name, elmt.expr(), indices_int.as_slice(), elmt, expr_index)?;
2285
2286 if let Some(contract_sum) = contract_sum {
2287 let contract_sum_value = self
2288 .build_load(self.real_type, contract_sum, "contract_sum")?
2289 .into_float_value();
2290 let new_contract_sum_value = self.builder.build_float_add(
2291 contract_sum_value,
2292 float_value,
2293 "new_contract_sum",
2294 )?;
2295 self.builder
2296 .build_store(contract_sum, new_contract_sum_value)?;
2297 } else {
2298 let expr_index = indices_int.iter().zip(expr_strides.iter()).fold(
2299 self.int_type.const_zero(),
2300 |acc, (i, s)| {
2301 let tmp = self.builder.build_int_mul(*i, *s, "expr_index").unwrap();
2302 self.builder.build_int_add(acc, tmp, "acc").unwrap()
2303 },
2304 );
2305 self.jit_compile_broadcast_and_store(name, elmt, float_value, expr_index, translation)?;
2306 }
2307
2308 let mut postblock = self.builder.get_insert_block().unwrap();
2309
2310 for i in (0..expr_rank).rev() {
2312 let next_index = self.builder.build_int_add(indices_int[i], one, name)?;
2314 indices[i].add_incoming(&[(&next_index, postblock)]);
2315 let store_point = (expr_rank - contract_by).saturating_sub(1);
2316 if i == store_point {
2317 if let Some(contract_sum) = contract_sum {
2318 let contract_sum_value = self
2319 .build_load(self.real_type, contract_sum, "contract_sum")?
2320 .into_float_value();
2321 let contract_strides = contract_strides.as_ref().unwrap();
2322 let elmt_index = indices_int
2323 .iter()
2324 .take(contract_strides.len())
2325 .zip(contract_strides.iter())
2326 .fold(self.int_type.const_zero(), |acc, (i, s)| {
2327 let tmp = self.builder.build_int_mul(*i, *s, "elmt_index").unwrap();
2328 self.builder.build_int_add(acc, tmp, "acc").unwrap()
2329 });
2330 self.jit_compile_store(
2331 name,
2332 elmt,
2333 elmt_index,
2334 contract_sum_value,
2335 translation,
2336 )?;
2337 }
2338 }
2339
2340 let end_index = if i == 0 && self.threaded {
2341 thread_end.unwrap()
2342 } else {
2343 expr_shape[i]
2344 };
2345
2346 let loop_while =
2348 self.builder
2349 .build_int_compare(IntPredicate::ULT, next_index, end_index, name)?;
2350 let block = self.context.append_basic_block(self.fn_value(), name);
2351 self.builder
2352 .build_conditional_branch(loop_while, blocks[i], block)?;
2353 self.builder.position_at_end(block);
2354 postblock = block;
2355 }
2356
2357 if self.threaded {
2358 self.jit_end_threading(
2359 thread_start.unwrap(),
2360 thread_end.unwrap(),
2361 test_done.unwrap(),
2362 next.unwrap(),
2363 )?;
2364 }
2365 Ok(())
2366 }
2367
2368 fn jit_compile_sparse_contraction_block(
2369 &mut self,
2370 name: &str,
2371 elmt: &TensorBlock,
2372 translation: &Translation,
2373 ) -> Result<()> {
2374 match translation.source {
2375 TranslationFrom::SparseContraction { .. } => {}
2376 _ => {
2377 panic!("expected sparse contraction")
2378 }
2379 }
2380 let int_type = self.int_type;
2381
2382 let translation_index = self
2383 .layout
2384 .get_translation_index(elmt.expr_layout(), elmt.layout())
2385 .unwrap();
2386 let translation_index = translation_index + translation.get_from_index_in_data_layout();
2387
2388 let final_contract_index =
2389 int_type.const_int(elmt.layout().nnz().try_into().unwrap(), false);
2390 let (thread_start, thread_end, test_done, next) = if self.threaded {
2391 let (start, end, test_done, next) = self.jit_threading_limits(final_contract_index)?;
2392 (Some(start), Some(end), Some(test_done), Some(next))
2393 } else {
2394 (None, None, None, None)
2395 };
2396
2397 let preblock = self.builder.get_insert_block().unwrap();
2398 let contract_sum_ptr = self.builder.build_alloca(self.real_type, "contract_sum")?;
2399
2400 let block = self.context.append_basic_block(self.fn_value(), name);
2402 self.builder.build_unconditional_branch(block)?;
2403 self.builder.position_at_end(block);
2404
2405 let contract_index = self.builder.build_phi(int_type, "i")?;
2406 let contract_start = if self.threaded {
2407 thread_start.unwrap()
2408 } else {
2409 int_type.const_zero()
2410 };
2411 contract_index.add_incoming(&[(&contract_start, preblock)]);
2412
2413 let start_index = self.builder.build_int_add(
2414 int_type.const_int(translation_index.try_into().unwrap(), false),
2415 self.builder.build_int_mul(
2416 int_type.const_int(2, false),
2417 contract_index.as_basic_value().into_int_value(),
2418 name,
2419 )?,
2420 name,
2421 )?;
2422 let end_index =
2423 self.builder
2424 .build_int_add(start_index, int_type.const_int(1, false), name)?;
2425 let start_ptr = self.build_gep(
2426 self.int_type,
2427 *self.get_param("indices"),
2428 &[start_index],
2429 "start_index_ptr",
2430 )?;
2431 let start_contract = self
2432 .build_load(self.int_type, start_ptr, "start")?
2433 .into_int_value();
2434 let end_ptr = self.build_gep(
2435 self.int_type,
2436 *self.get_param("indices"),
2437 &[end_index],
2438 "end_index_ptr",
2439 )?;
2440 let end_contract = self
2441 .build_load(self.int_type, end_ptr, "end")?
2442 .into_int_value();
2443
2444 self.builder
2446 .build_store(contract_sum_ptr, self.real_type.const_float(0.0))?;
2447
2448 let start_contract_block = self
2450 .context
2451 .append_basic_block(self.fn_value(), format!("{name}_contract").as_str());
2452 self.builder
2453 .build_unconditional_branch(start_contract_block)?;
2454 self.builder.position_at_end(start_contract_block);
2455
2456 let expr_index_phi = self.builder.build_phi(int_type, "j")?;
2457 expr_index_phi.add_incoming(&[(&start_contract, block)]);
2458
2459 let expr_index = expr_index_phi.as_basic_value().into_int_value();
2460 let indices_int = self.expr_indices_from_elmt_index(expr_index, elmt, name)?;
2461
2462 let float_value =
2464 self.jit_compile_expr(name, elmt.expr(), indices_int.as_slice(), elmt, expr_index)?;
2465 let contract_sum_value = self
2466 .build_load(self.real_type, contract_sum_ptr, "contract_sum")?
2467 .into_float_value();
2468 let new_contract_sum_value =
2469 self.builder
2470 .build_float_add(contract_sum_value, float_value, "new_contract_sum")?;
2471 self.builder
2472 .build_store(contract_sum_ptr, new_contract_sum_value)?;
2473
2474 let end_contract_block = self.builder.get_insert_block().unwrap();
2475
2476 let next_elmt_index =
2478 self.builder
2479 .build_int_add(expr_index, int_type.const_int(1, false), name)?;
2480 expr_index_phi.add_incoming(&[(&next_elmt_index, end_contract_block)]);
2481
2482 let loop_while = self.builder.build_int_compare(
2484 IntPredicate::ULT,
2485 next_elmt_index,
2486 end_contract,
2487 name,
2488 )?;
2489 let post_contract_block = self.context.append_basic_block(self.fn_value(), name);
2490 self.builder.build_conditional_branch(
2491 loop_while,
2492 start_contract_block,
2493 post_contract_block,
2494 )?;
2495 self.builder.position_at_end(post_contract_block);
2496
2497 self.jit_compile_store(
2499 name,
2500 elmt,
2501 contract_index.as_basic_value().into_int_value(),
2502 new_contract_sum_value,
2503 translation,
2504 )?;
2505
2506 let next_contract_index = self.builder.build_int_add(
2508 contract_index.as_basic_value().into_int_value(),
2509 int_type.const_int(1, false),
2510 name,
2511 )?;
2512 contract_index.add_incoming(&[(&next_contract_index, post_contract_block)]);
2513
2514 let loop_while = self.builder.build_int_compare(
2516 IntPredicate::ULT,
2517 next_contract_index,
2518 thread_end.unwrap_or(final_contract_index),
2519 name,
2520 )?;
2521 let post_block = self.context.append_basic_block(self.fn_value(), name);
2522 self.builder
2523 .build_conditional_branch(loop_while, block, post_block)?;
2524 self.builder.position_at_end(post_block);
2525
2526 if self.threaded {
2527 self.jit_end_threading(
2528 thread_start.unwrap(),
2529 thread_end.unwrap(),
2530 test_done.unwrap(),
2531 next.unwrap(),
2532 )?;
2533 }
2534
2535 Ok(())
2536 }
2537
2538 fn expr_indices_from_elmt_index(
2539 &mut self,
2540 elmt_index: IntValue<'ctx>,
2541 elmt: &TensorBlock,
2542 name: &str,
2543 ) -> Result<Vec<IntValue<'ctx>>, anyhow::Error> {
2544 let layout_index = self.layout.get_layout_index(elmt.expr_layout()).unwrap();
2545 let int_type = self.int_type;
2546 let elmt_index_mult_rank = self.builder.build_int_mul(
2548 elmt_index,
2549 int_type.const_int(elmt.expr_layout().rank().try_into().unwrap(), false),
2550 name,
2551 )?;
2552 (0..elmt.expr_layout().rank())
2553 .map(|i| {
2554 let layout_index_plus_offset =
2555 int_type.const_int((layout_index + i).try_into().unwrap(), false);
2556 let curr_index = self.builder.build_int_add(
2557 elmt_index_mult_rank,
2558 layout_index_plus_offset,
2559 name,
2560 )?;
2561 let ptr = Self::get_ptr_to_index(
2562 &self.builder,
2563 self.int_type,
2564 self.get_param("indices"),
2565 curr_index,
2566 name,
2567 );
2568 Ok(self.build_load(self.int_type, ptr, name)?.into_int_value())
2569 })
2570 .collect::<Result<Vec<_>, anyhow::Error>>()
2571 }
2572
2573 fn jit_compile_sparse_block(
2576 &mut self,
2577 name: &str,
2578 elmt: &TensorBlock,
2579 translation: &Translation,
2580 ) -> Result<()> {
2581 let int_type = self.int_type;
2582
2583 let start_index = int_type.const_int(0, false);
2584 let end_index = int_type.const_int(elmt.expr_layout().nnz().try_into().unwrap(), false);
2585
2586 let (thread_start, thread_end, test_done, next) = if self.threaded {
2587 let (start, end, test_done, next) = self.jit_threading_limits(end_index)?;
2588 (Some(start), Some(end), Some(test_done), Some(next))
2589 } else {
2590 (None, None, None, None)
2591 };
2592
2593 let preblock = self.builder.get_insert_block().unwrap();
2595 let loop_block = self.context.append_basic_block(self.fn_value(), name);
2596 self.builder.build_unconditional_branch(loop_block)?;
2597 self.builder.position_at_end(loop_block);
2598
2599 let curr_index = self.builder.build_phi(int_type, "i")?;
2600 curr_index.add_incoming(&[(&thread_start.unwrap_or(start_index), preblock)]);
2601
2602 let elmt_index = curr_index.as_basic_value().into_int_value();
2603 let indices_int = self.expr_indices_from_elmt_index(elmt_index, elmt, name)?;
2604
2605 let float_value =
2607 self.jit_compile_expr(name, elmt.expr(), indices_int.as_slice(), elmt, elmt_index)?;
2608
2609 self.jit_compile_broadcast_and_store(name, elmt, float_value, elmt_index, translation)?;
2610
2611 let end_loop_block = self.builder.get_insert_block().unwrap();
2613
2614 let one = int_type.const_int(1, false);
2616 let next_index = self.builder.build_int_add(elmt_index, one, name)?;
2617 curr_index.add_incoming(&[(&next_index, end_loop_block)]);
2618
2619 let loop_while = self.builder.build_int_compare(
2621 IntPredicate::ULT,
2622 next_index,
2623 thread_end.unwrap_or(end_index),
2624 name,
2625 )?;
2626 let post_block = self.context.append_basic_block(self.fn_value(), name);
2627 self.builder
2628 .build_conditional_branch(loop_while, loop_block, post_block)?;
2629 self.builder.position_at_end(post_block);
2630
2631 if self.threaded {
2632 self.jit_end_threading(
2633 thread_start.unwrap(),
2634 thread_end.unwrap(),
2635 test_done.unwrap(),
2636 next.unwrap(),
2637 )?;
2638 }
2639
2640 Ok(())
2641 }
2642
2643 fn jit_compile_diagonal_block(
2645 &mut self,
2646 name: &str,
2647 elmt: &TensorBlock,
2648 translation: &Translation,
2649 ) -> Result<()> {
2650 let int_type = self.int_type;
2651
2652 let start_index = int_type.const_int(0, false);
2653 let end_index = int_type.const_int(elmt.expr_layout().nnz().try_into().unwrap(), false);
2654
2655 let (thread_start, thread_end, test_done, next) = if self.threaded {
2656 let (start, end, test_done, next) = self.jit_threading_limits(end_index)?;
2657 (Some(start), Some(end), Some(test_done), Some(next))
2658 } else {
2659 (None, None, None, None)
2660 };
2661
2662 let preblock = self.builder.get_insert_block().unwrap();
2664 let start_loop_block = self.context.append_basic_block(self.fn_value(), name);
2665 self.builder.build_unconditional_branch(start_loop_block)?;
2666 self.builder.position_at_end(start_loop_block);
2667
2668 let curr_index = self.builder.build_phi(int_type, "i")?;
2669 curr_index.add_incoming(&[(&thread_start.unwrap_or(start_index), preblock)]);
2670
2671 let elmt_index = curr_index.as_basic_value().into_int_value();
2673 let indices_int: Vec<IntValue> =
2674 (0..elmt.expr_layout().rank()).map(|_| elmt_index).collect();
2675
2676 let float_value =
2678 self.jit_compile_expr(name, elmt.expr(), indices_int.as_slice(), elmt, elmt_index)?;
2679
2680 self.jit_compile_broadcast_and_store(name, elmt, float_value, elmt_index, translation)?;
2682
2683 let end_loop_block = self.builder.get_insert_block().unwrap();
2684
2685 let one = int_type.const_int(1, false);
2687 let next_index = self.builder.build_int_add(elmt_index, one, name)?;
2688 curr_index.add_incoming(&[(&next_index, end_loop_block)]);
2689
2690 let loop_while = self.builder.build_int_compare(
2692 IntPredicate::ULT,
2693 next_index,
2694 thread_end.unwrap_or(end_index),
2695 name,
2696 )?;
2697 let post_block = self.context.append_basic_block(self.fn_value(), name);
2698 self.builder
2699 .build_conditional_branch(loop_while, start_loop_block, post_block)?;
2700 self.builder.position_at_end(post_block);
2701
2702 if self.threaded {
2703 self.jit_end_threading(
2704 thread_start.unwrap(),
2705 thread_end.unwrap(),
2706 test_done.unwrap(),
2707 next.unwrap(),
2708 )?;
2709 }
2710
2711 Ok(())
2712 }
2713
2714 fn jit_compile_broadcast_and_store(
2715 &mut self,
2716 name: &str,
2717 elmt: &TensorBlock,
2718 float_value: FloatValue<'ctx>,
2719 expr_index: IntValue<'ctx>,
2720 translation: &Translation,
2721 ) -> Result<()> {
2722 let int_type = self.int_type;
2723 let one = int_type.const_int(1, false);
2724 let zero = int_type.const_int(0, false);
2725 let pre_block = self.builder.get_insert_block().unwrap();
2726 match translation.source {
2727 TranslationFrom::Broadcast {
2728 broadcast_by: _,
2729 broadcast_len,
2730 } => {
2731 let bcast_start_index = zero;
2732 let bcast_end_index = int_type.const_int(broadcast_len.try_into().unwrap(), false);
2733
2734 let bcast_block = self.context.append_basic_block(self.fn_value(), name);
2736 self.builder.build_unconditional_branch(bcast_block)?;
2737 self.builder.position_at_end(bcast_block);
2738 let bcast_index = self.builder.build_phi(int_type, "broadcast_index")?;
2739 bcast_index.add_incoming(&[(&bcast_start_index, pre_block)]);
2740
2741 let store_index = self.builder.build_int_add(
2743 self.builder
2744 .build_int_mul(expr_index, bcast_end_index, "store_index")?,
2745 bcast_index.as_basic_value().into_int_value(),
2746 "bcast_store_index",
2747 )?;
2748 self.jit_compile_store(name, elmt, store_index, float_value, translation)?;
2749
2750 let bcast_next_index = self.builder.build_int_add(
2752 bcast_index.as_basic_value().into_int_value(),
2753 one,
2754 name,
2755 )?;
2756 bcast_index.add_incoming(&[(&bcast_next_index, bcast_block)]);
2757
2758 let bcast_cond = self.builder.build_int_compare(
2760 IntPredicate::ULT,
2761 bcast_next_index,
2762 bcast_end_index,
2763 "broadcast_cond",
2764 )?;
2765 let post_bcast_block = self.context.append_basic_block(self.fn_value(), name);
2766 self.builder
2767 .build_conditional_branch(bcast_cond, bcast_block, post_bcast_block)?;
2768 self.builder.position_at_end(post_bcast_block);
2769 Ok(())
2770 }
2771 TranslationFrom::ElementWise | TranslationFrom::DiagonalContraction { .. } => {
2772 self.jit_compile_store(name, elmt, expr_index, float_value, translation)?;
2773 Ok(())
2774 }
2775 _ => Err(anyhow!("Invalid translation")),
2776 }
2777 }
2778
2779 fn jit_compile_store(
2780 &mut self,
2781 name: &str,
2782 elmt: &TensorBlock,
2783 store_index: IntValue<'ctx>,
2784 float_value: FloatValue<'ctx>,
2785 translation: &Translation,
2786 ) -> Result<()> {
2787 let int_type = self.int_type;
2788 let res_index = match &translation.target {
2789 TranslationTo::Contiguous { start, end: _ } => {
2790 let start_const = int_type.const_int((*start).try_into().unwrap(), false);
2791 self.builder.build_int_add(start_const, store_index, name)?
2792 }
2793 TranslationTo::Sparse { indices: _ } => {
2794 let translate_index = self
2796 .layout
2797 .get_translation_index(elmt.expr_layout(), elmt.layout())
2798 .unwrap();
2799 let translate_store_index =
2800 translate_index + translation.get_to_index_in_data_layout();
2801 let translate_store_index =
2802 int_type.const_int(translate_store_index.try_into().unwrap(), false);
2803 let elmt_index_strided = store_index;
2804 let curr_index =
2805 self.builder
2806 .build_int_add(elmt_index_strided, translate_store_index, name)?;
2807 let ptr = Self::get_ptr_to_index(
2808 &self.builder,
2809 self.int_type,
2810 self.get_param("indices"),
2811 curr_index,
2812 name,
2813 );
2814 self.build_load(self.int_type, ptr, name)?.into_int_value()
2815 }
2816 };
2817
2818 let resi_ptr = Self::get_ptr_to_index(
2819 &self.builder,
2820 self.real_type,
2821 &self.tensor_ptr(),
2822 res_index,
2823 name,
2824 );
2825 self.builder.build_store(resi_ptr, float_value)?;
2826 Ok(())
2827 }
2828
2829 fn jit_compile_expr(
2830 &mut self,
2831 name: &str,
2832 expr: &Ast,
2833 index: &[IntValue<'ctx>],
2834 elmt: &TensorBlock,
2835 expr_index: IntValue<'ctx>,
2836 ) -> Result<FloatValue<'ctx>> {
2837 let name = elmt.name().unwrap_or(name);
2838 match &expr.kind {
2839 AstKind::Binop(binop) => {
2840 let lhs =
2841 self.jit_compile_expr(name, binop.left.as_ref(), index, elmt, expr_index)?;
2842 let rhs =
2843 self.jit_compile_expr(name, binop.right.as_ref(), index, elmt, expr_index)?;
2844 match binop.op {
2845 '*' => Ok(self.builder.build_float_mul(lhs, rhs, name)?),
2846 '/' => Ok(self.builder.build_float_div(lhs, rhs, name)?),
2847 '-' => Ok(self.builder.build_float_sub(lhs, rhs, name)?),
2848 '+' => Ok(self.builder.build_float_add(lhs, rhs, name)?),
2849 unknown => Err(anyhow!("unknown binop op '{}'", unknown)),
2850 }
2851 }
2852 AstKind::Monop(monop) => {
2853 let child =
2854 self.jit_compile_expr(name, monop.child.as_ref(), index, elmt, expr_index)?;
2855 match monop.op {
2856 '-' => Ok(self.builder.build_float_neg(child, name)?),
2857 unknown => Err(anyhow!("unknown monop op '{}'", unknown)),
2858 }
2859 }
2860 AstKind::Call(call) => match self.get_function(call.fn_name) {
2861 Some(function) => {
2862 let mut args: Vec<BasicMetadataValueEnum> = Vec::new();
2863 for arg in call.args.iter() {
2864 let arg_val =
2865 self.jit_compile_expr(name, arg.as_ref(), index, elmt, expr_index)?;
2866 args.push(BasicMetadataValueEnum::FloatValue(arg_val));
2867 }
2868 let ret_value = self
2869 .builder
2870 .build_call(function, args.as_slice(), name)?
2871 .try_as_basic_value()
2872 .unwrap_basic()
2873 .into_float_value();
2874 Ok(ret_value)
2875 }
2876 None => Err(anyhow!("unknown function call '{}'", call.fn_name)),
2877 },
2878 AstKind::CallArg(arg) => {
2879 self.jit_compile_expr(name, &arg.expression, index, elmt, expr_index)
2880 }
2881 AstKind::Number(value) => Ok(self.real_type.const_float(*value)),
2882 AstKind::Name(iname) => {
2883 if iname.name == "N" {
2884 if iname.is_tangent {
2885 return Ok(self.real_type.const_float(0.0));
2886 }
2887 let model_index = self
2888 .build_load(self.int_type, *self.get_param("model_index"), "model_index")?
2889 .into_int_value();
2890 let n_value = self.builder.build_signed_int_to_float(
2891 model_index,
2892 self.real_type,
2893 "n_as_real",
2894 )?;
2895 return Ok(n_value);
2896 }
2897 let ptr = *self.get_param(iname.name);
2898 let layout = self.layout.get_layout(iname.name).unwrap().clone();
2899 let iname_elmt_index = if layout.is_dense() {
2900 let mut no_transform = true;
2902 let mut iname_index = Vec::new();
2903 for (i, c) in iname.indices.iter().enumerate() {
2904 let pi = elmt
2907 .indices()
2908 .iter()
2909 .position(|x| x == c)
2910 .unwrap_or(elmt.indices().len());
2911 if let Some(indice_ast) = iname.indice.as_ref() {
2913 let Some(indice) = indice_ast.kind.as_indice() else {
2914 return Err(anyhow!("invalid index expression '{}'", indice_ast));
2915 };
2916 let start_intval =
2917 self.jit_compile_integer_expr(indice.first.as_ref(), name)?;
2918 let index_pi = if indice.last.is_some() {
2919 let index_pi = if pi >= index.len() {
2921 self.context.i32_type().const_int(0, false)
2922 } else {
2923 index[pi]
2924 };
2925 self.builder.build_int_add(index_pi, start_intval, name)?
2926 } else {
2927 start_intval
2929 };
2930 iname_index.push(index_pi);
2931 } else {
2932 let index_pi = if pi >= index.len() {
2933 self.context.i32_type().const_int(0, false)
2934 } else {
2935 index[pi]
2936 };
2937 iname_index.push(index_pi);
2938 }
2939 no_transform = no_transform && pi == i;
2940 }
2941 let iname_index: Vec<IntValue> = iname_index
2943 .iter()
2944 .enumerate()
2945 .map(|(i, &idx)| {
2946 if layout.shape()[i] == 1 {
2947 self.context.i32_type().const_int(0, false)
2948 } else {
2949 idx
2950 }
2951 })
2952 .collect();
2953 if !iname_index.is_empty() {
2956 let mut iname_elmt_index = *iname_index.last().unwrap();
2958 let mut stride = 1u64;
2959 for i in (0..iname_index.len() - 1).rev() {
2960 let iname_i = iname_index[i];
2961 let shapei: u64 = layout.shape()[i + 1].try_into().unwrap();
2962 stride *= shapei;
2963 let stride_intval = self.context.i32_type().const_int(stride, false);
2964 let stride_mul_i =
2965 self.builder.build_int_mul(stride_intval, iname_i, name)?;
2966 iname_elmt_index =
2967 self.builder
2968 .build_int_add(iname_elmt_index, stride_mul_i, name)?;
2969 }
2970 iname_elmt_index
2971 } else {
2972 let zero = self.context.i32_type().const_int(0, false);
2974 zero
2975 }
2976 } else if layout.is_sparse() || layout.is_diagonal() {
2977 let expr_layout = elmt.expr_layout();
2978
2979 if expr_layout != &layout {
2980 let permutation =
2987 DataLayout::permutation(elmt, iname.indices.as_slice(), &layout);
2988 if let Some(base_binary_layout_index) =
2989 self.layout
2990 .get_binary_layout_index(&layout, expr_layout, permutation)
2991 {
2992 let binary_layout_index = self.builder.build_int_add(
2993 self.int_type
2994 .const_int(base_binary_layout_index.try_into().unwrap(), false),
2995 expr_index,
2996 name,
2997 )?;
2998
2999 let indices_ptr = Self::get_ptr_to_index(
3000 &self.builder,
3001 self.int_type,
3002 self.get_param("indices"),
3003 binary_layout_index,
3004 name,
3005 );
3006 let mapped_index = self
3007 .build_load(self.int_type, indices_ptr, name)?
3008 .into_int_value();
3009 let is_less_than_zero = self.builder.build_int_compare(
3010 IntPredicate::SLT,
3011 mapped_index,
3012 self.int_type.const_int(0, true),
3013 "sparse_index_check",
3014 )?;
3015 let is_less_than_zero_block =
3016 self.context.append_basic_block(self.fn_value(), "lt_zero");
3017 let not_less_than_zero_block = self
3018 .context
3019 .append_basic_block(self.fn_value(), "not_lt_zero");
3020 let merge_block =
3021 self.context.append_basic_block(self.fn_value(), "merge");
3022 self.builder.build_conditional_branch(
3023 is_less_than_zero,
3024 is_less_than_zero_block,
3025 not_less_than_zero_block,
3026 )?;
3027
3028 self.builder.position_at_end(is_less_than_zero_block);
3030 let zero_value = self.real_type.const_float(0.);
3031 self.builder.build_unconditional_branch(merge_block)?;
3032
3033 self.builder.position_at_end(not_less_than_zero_block);
3035 let value_ptr = Self::get_ptr_to_index(
3036 &self.builder,
3037 self.real_type,
3038 &ptr,
3039 mapped_index,
3040 name,
3041 );
3042 let value = self
3043 .build_load(self.real_type, value_ptr, name)?
3044 .into_float_value();
3045 self.builder.build_unconditional_branch(merge_block)?;
3046
3047 self.builder.position_at_end(merge_block);
3049 let if_return_value =
3050 self.builder.build_phi(self.real_type, "sparse_value")?;
3051 if_return_value.add_incoming(&[(&zero_value, is_less_than_zero_block)]);
3052 if_return_value.add_incoming(&[(&value, not_less_than_zero_block)]);
3053
3054 let phi_value = if_return_value.as_basic_value().into_float_value();
3055 return Ok(phi_value);
3056 } else {
3057 expr_index
3058 }
3059 } else {
3060 expr_index
3062 }
3063 } else {
3064 panic!("unexpected layout");
3065 };
3066 let value_ptr = Self::get_ptr_to_index(
3067 &self.builder,
3068 self.real_type,
3069 &ptr,
3070 iname_elmt_index,
3071 name,
3072 );
3073 Ok(self
3074 .build_load(self.real_type, value_ptr, name)?
3075 .into_float_value())
3076 }
3077 AstKind::NamedGradient(name) => {
3078 let name_str = name.to_string();
3079 let ptr = self.get_param(name_str.as_str());
3080 Ok(self
3081 .build_load(self.real_type, *ptr, name_str.as_str())?
3082 .into_float_value())
3083 }
3084 AstKind::Index(_) => todo!(),
3085 AstKind::Slice(_) => todo!(),
3086 AstKind::Integer(_) => todo!(),
3087 _ => panic!("unexprected astkind"),
3088 }
3089 }
3090
3091 fn jit_compile_integer_expr(&mut self, expr: &Ast, name: &str) -> Result<IntValue<'ctx>> {
3092 match &expr.kind {
3093 AstKind::Integer(value) => Ok(self.int_type.const_int(*value as u64, true)),
3094 AstKind::Number(value) => {
3095 if value.fract() != 0.0 {
3096 return Err(anyhow!(
3097 "non-integer value '{}' in integer expression",
3098 value
3099 ));
3100 }
3101 Ok(self.int_type.const_int(*value as u64, true))
3102 }
3103 AstKind::Name(iname) => {
3104 if iname.name == "N" {
3105 Ok(self
3106 .build_load(self.int_type, *self.get_param("model_index"), name)?
3107 .into_int_value())
3108 } else {
3109 Err(anyhow!(
3110 "unsupported name '{}' in integer expression",
3111 iname.name
3112 ))
3113 }
3114 }
3115 AstKind::Monop(monop) => {
3116 let child = self.jit_compile_integer_expr(monop.child.as_ref(), name)?;
3117 match monop.op {
3118 '+' => Ok(child),
3119 '-' => self.builder.build_int_neg(child, name).map_err(Into::into),
3120 _ => Err(anyhow!("unknown integer unary op '{}'", monop.op)),
3121 }
3122 }
3123 AstKind::Binop(binop) => {
3124 let lhs = self.jit_compile_integer_expr(binop.left.as_ref(), name)?;
3125 let rhs = self.jit_compile_integer_expr(binop.right.as_ref(), name)?;
3126 match binop.op {
3127 '+' => self
3128 .builder
3129 .build_int_add(lhs, rhs, name)
3130 .map_err(Into::into),
3131 '-' => self
3132 .builder
3133 .build_int_sub(lhs, rhs, name)
3134 .map_err(Into::into),
3135 '*' => self
3136 .builder
3137 .build_int_mul(lhs, rhs, name)
3138 .map_err(Into::into),
3139 '/' => self
3140 .builder
3141 .build_int_signed_div(lhs, rhs, name)
3142 .map_err(Into::into),
3143 '%' => self
3144 .builder
3145 .build_int_signed_rem(lhs, rhs, name)
3146 .map_err(Into::into),
3147 _ => Err(anyhow!("unknown integer binary op '{}'", binop.op)),
3148 }
3149 }
3150 _ => Err(anyhow!("unsupported integer expression '{}'", expr)),
3151 }
3152 }
3153
3154 fn clear(&mut self) {
3155 self.variables.clear();
3156 self.fn_value_opt = None;
3158 self.tensor_ptr_opt = None;
3159 }
3160
3161 fn build_dep_call(
3162 &mut self,
3163 dep_fn: FunctionValue<'ctx>,
3164 call_name: &str,
3165 barrier_start: u64,
3166 total_barriers: u64,
3167 ) -> Result<()> {
3168 let t = self
3169 .build_load(self.real_type, *self.get_param("t"), "t")?
3170 .into_float_value();
3171 let u = *self.get_param("u");
3172 let data = *self.get_param("data");
3173 let thread_id = self
3174 .build_load(self.int_type, *self.get_param("thread_id"), "thread_id")?
3175 .into_int_value();
3176 let thread_dim = self
3177 .build_load(self.int_type, *self.get_param("thread_dim"), "thread_dim")?
3178 .into_int_value();
3179 let barrier_start = self.int_type.const_int(barrier_start, false);
3180 let total_barriers = self.int_type.const_int(total_barriers, false);
3181 self.builder.build_call(
3182 dep_fn,
3183 &[
3184 t.into(),
3185 u.into(),
3186 data.into(),
3187 thread_id.into(),
3188 thread_dim.into(),
3189 barrier_start.into(),
3190 total_barriers.into(),
3191 ],
3192 call_name,
3193 )?;
3194 Ok(())
3195 }
3196
3197 fn ensure_time_dep_fn<'m>(
3198 &mut self,
3199 model: &'m DiscreteModel,
3200 code: Option<&str>,
3201 ) -> Result<FunctionValue<'ctx>> {
3202 if let Some(function) = self.module.get_function("calc_time_dep") {
3203 return Ok(function);
3204 }
3205 self.compile_dep_defns(model, "calc_time_dep", model.time_dep_defns(), code)
3206 }
3207
3208 fn ensure_state_dep_fn<'m>(
3209 &mut self,
3210 model: &'m DiscreteModel,
3211 code: Option<&str>,
3212 ) -> Result<FunctionValue<'ctx>> {
3213 if let Some(function) = self.module.get_function("calc_state_dep") {
3214 return Ok(function);
3215 }
3216 self.compile_dep_defns(model, "calc_state_dep", model.state_dep_defns(), code)
3217 }
3218
3219 fn ensure_state_dep_post_f_fn<'m>(
3220 &mut self,
3221 model: &'m DiscreteModel,
3222 code: Option<&str>,
3223 ) -> Result<FunctionValue<'ctx>> {
3224 if let Some(function) = self.module.get_function("calc_state_dep_post_f") {
3225 return Ok(function);
3226 }
3227 self.compile_dep_defns(
3228 model,
3229 "calc_state_dep_post_f",
3230 model.state_dep_post_f_defns(),
3231 code,
3232 )
3233 }
3234
3235 fn function_arg_alloca(&mut self, name: &str, arg: BasicValueEnum<'ctx>) -> PointerValue<'ctx> {
3236 match arg {
3237 BasicValueEnum::PointerValue(v) => v,
3238 BasicValueEnum::FloatValue(v) => {
3239 let alloca = self
3240 .create_entry_block_builder()
3241 .build_alloca(arg.get_type(), name)
3242 .unwrap();
3243 self.builder.build_store(alloca, v).unwrap();
3244 alloca
3245 }
3246 BasicValueEnum::IntValue(v) => {
3247 let alloca = self
3248 .create_entry_block_builder()
3249 .build_alloca(arg.get_type(), name)
3250 .unwrap();
3251 self.builder.build_store(alloca, v).unwrap();
3252 alloca
3253 }
3254 _ => unreachable!(),
3255 }
3256 }
3257
3258 pub fn compile_set_u0<'m>(
3259 &mut self,
3260 model: &'m DiscreteModel,
3261 code: Option<&str>,
3262 ) -> Result<FunctionValue<'ctx>> {
3263 self.clear();
3264 let fn_arg_names = &["u0", "data", "thread_id", "thread_dim"];
3265 let function = self.add_function(
3266 "set_u0",
3267 fn_arg_names,
3268 &[
3269 self.real_ptr_type.into(),
3270 self.real_ptr_type.into(),
3271 self.int_type.into(),
3272 self.int_type.into(),
3273 ],
3274 None,
3275 false,
3276 );
3277
3278 let alias_id = Attribute::get_named_enum_kind_id("noalias");
3280 let noalign = self.context.create_enum_attribute(alias_id, 0);
3281 for i in &[0, 1] {
3282 function.add_attribute(AttributeLoc::Param(*i), noalign);
3283 }
3284
3285 let _basic_block = self.start_function(function, code);
3286
3287 for (i, arg) in function.get_param_iter().enumerate() {
3288 let name = fn_arg_names[i];
3289 let alloca = self.function_arg_alloca(name, arg);
3290 self.insert_param(name, alloca);
3291 }
3292
3293 self.insert_data(model);
3294 self.insert_indices();
3295
3296 let mut nbarriers = 0;
3297 let total_barriers = (model.input_dep_defns().len() + 1) as u64;
3298 let total_barriers_val = self.int_type.const_int(total_barriers, false);
3299 #[allow(clippy::explicit_counter_loop)]
3300 for a in model.input_dep_defns() {
3301 self.jit_compile_tensor(a, Some(*self.get_var(a)), code)?;
3302 let barrier_num = self.int_type.const_int(nbarriers + 1, false);
3303 self.jit_compile_call_barrier(barrier_num, total_barriers_val);
3304 nbarriers += 1;
3305 }
3306
3307 self.jit_compile_tensor(model.state(), Some(*self.get_param("u0")), code)?;
3308 let barrier_num = self.int_type.const_int(nbarriers + 1, false);
3309 self.jit_compile_call_barrier(barrier_num, total_barriers_val);
3310
3311 self.builder.build_return(None)?;
3312
3313 if function.verify(true) {
3314 Ok(function)
3315 } else {
3316 function.print_to_stderr();
3317 unsafe {
3318 function.delete();
3319 }
3320 Err(anyhow!("Invalid generated function."))
3321 }
3322 }
3323
3324 pub fn compile_calc_out<'m>(
3325 &mut self,
3326 model: &'m DiscreteModel,
3327 include_constants: bool,
3328 code: Option<&str>,
3329 ) -> Result<FunctionValue<'ctx>> {
3330 let time_dep_fn = self.ensure_time_dep_fn(model, code)?;
3331 let state_dep_fn = self.ensure_state_dep_fn(model, code)?;
3332 let state_dep_post_f_fn = self.ensure_state_dep_post_f_fn(model, code)?;
3333 self.clear();
3334 let fn_arg_names = &["t", "u", "data", "out", "thread_id", "thread_dim"];
3335 let function_name = if include_constants {
3336 "calc_out_full"
3337 } else {
3338 "calc_out"
3339 };
3340 let function = self.add_function(
3341 function_name,
3342 fn_arg_names,
3343 &[
3344 self.real_type.into(),
3345 self.real_ptr_type.into(),
3346 self.real_ptr_type.into(),
3347 self.real_ptr_type.into(),
3348 self.int_type.into(),
3349 self.int_type.into(),
3350 ],
3351 None,
3352 false,
3353 );
3354
3355 let alias_id = Attribute::get_named_enum_kind_id("noalias");
3357 let noalign = self.context.create_enum_attribute(alias_id, 0);
3358 for i in &[1, 2, 3] {
3359 function.add_attribute(AttributeLoc::Param(*i), noalign);
3360 }
3361
3362 let _basic_block = self.start_function(function, code);
3363
3364 for (i, arg) in function.get_param_iter().enumerate() {
3365 let name = fn_arg_names[i];
3366 let alloca = self.function_arg_alloca(name, arg);
3367 self.insert_param(name, alloca);
3368 }
3369
3370 self.insert_state(model.state());
3371 self.insert_data(model);
3372 self.insert_indices();
3373
3374 if let Some(out) = model.out() {
3380 let mut nbarriers = 0;
3381 let mut total_barriers = (model.time_dep_defns().len()
3382 + model.state_dep_defns().len()
3383 + model.state_dep_post_f_defns().len()
3384 + 1) as u64;
3385 if include_constants {
3386 total_barriers += model.input_dep_defns().len() as u64;
3387 }
3388 let total_barriers_val = self.int_type.const_int(total_barriers, false);
3389 if include_constants {
3390 for tensor in model.input_dep_defns() {
3392 self.jit_compile_tensor(tensor, Some(*self.get_var(tensor)), code)?;
3393 let barrier_num = self.int_type.const_int(nbarriers + 1, false);
3394 self.jit_compile_call_barrier(barrier_num, total_barriers_val);
3395 nbarriers += 1;
3396 }
3397 }
3398
3399 if !model.time_dep_defns().is_empty() {
3401 self.build_dep_call(time_dep_fn, "time_dep", nbarriers, total_barriers)?;
3402 nbarriers += model.time_dep_defns().len() as u64;
3403 }
3404
3405 if !model.state_dep_defns().is_empty() {
3407 self.build_dep_call(state_dep_fn, "state_dep", nbarriers, total_barriers)?;
3408 nbarriers += model.state_dep_defns().len() as u64;
3409 }
3410 if !model.state_dep_post_f_defns().is_empty() {
3411 self.build_dep_call(state_dep_post_f_fn, "state_dep", nbarriers, total_barriers)?;
3412 nbarriers += model.state_dep_post_f_defns().len() as u64;
3413 }
3414
3415 self.jit_compile_tensor(out, Some(*self.get_var(model.out().unwrap())), code)?;
3416 let barrier_num = self.int_type.const_int(nbarriers + 1, false);
3417 self.jit_compile_call_barrier(barrier_num, total_barriers_val);
3418 }
3419 self.builder.build_return(None)?;
3420
3421 if function.verify(true) {
3422 Ok(function)
3423 } else {
3424 function.print_to_stderr();
3425 unsafe {
3426 function.delete();
3427 }
3428 Err(anyhow!("Invalid generated function."))
3429 }
3430 }
3431
3432 fn compile_dep_defns<'m>(
3433 &mut self,
3434 model: &'m DiscreteModel,
3435 fn_name: &str,
3436 tensors: &[Tensor<'m>],
3437 code: Option<&str>,
3438 ) -> Result<FunctionValue<'ctx>> {
3439 self.clear();
3440 let fn_arg_names = &[
3441 "t",
3442 "u",
3443 "data",
3444 "thread_id",
3445 "thread_dim",
3446 "barrier_start",
3447 "total_barriers",
3448 ];
3449 let function = self.add_function(
3450 fn_name,
3451 fn_arg_names,
3452 &[
3453 self.real_type.into(),
3454 self.real_ptr_type.into(),
3455 self.real_ptr_type.into(),
3456 self.int_type.into(),
3457 self.int_type.into(),
3458 self.int_type.into(),
3459 self.int_type.into(),
3460 ],
3461 None,
3462 false,
3463 );
3464
3465 let alias_id = Attribute::get_named_enum_kind_id("noalias");
3467 let noalign = self.context.create_enum_attribute(alias_id, 0);
3468 for i in &[1, 2] {
3469 function.add_attribute(AttributeLoc::Param(*i), noalign);
3470 }
3471
3472 let _basic_block = self.start_function(function, code);
3473
3474 for (i, arg) in function.get_param_iter().enumerate() {
3475 let name = fn_arg_names[i];
3476 let alloca = self.function_arg_alloca(name, arg);
3477 self.insert_param(name, alloca);
3478 }
3479
3480 self.insert_state(model.state());
3481 self.insert_data(model);
3482 self.insert_indices();
3483
3484 let barrier_start = self
3485 .build_load(
3486 self.int_type,
3487 *self.get_param("barrier_start"),
3488 "barrier_start",
3489 )?
3490 .into_int_value();
3491 let total_barriers = self
3492 .build_load(
3493 self.int_type,
3494 *self.get_param("total_barriers"),
3495 "total_barriers",
3496 )?
3497 .into_int_value();
3498 let one = self.int_type.const_int(1, false);
3499
3500 for (index, tensor) in tensors.iter().enumerate() {
3501 self.jit_compile_tensor(tensor, Some(*self.get_var(tensor)), code)?;
3502 let index = self.int_type.const_int(index as u64, false);
3503 let barrier_num_base =
3504 self.builder
3505 .build_int_add(barrier_start, index, "barrier_num_base")?;
3506 let barrier_num = self
3507 .builder
3508 .build_int_add(barrier_num_base, one, "barrier_num")?;
3509 self.jit_compile_call_barrier(barrier_num, total_barriers);
3510 }
3511
3512 self.builder.build_return(None)?;
3513
3514 if function.verify(true) {
3515 Ok(function)
3516 } else {
3517 function.print_to_stderr();
3518 unsafe {
3519 function.delete();
3520 }
3521 Err(anyhow!("Invalid generated function."))
3522 }
3523 }
3524
3525 pub fn compile_calc_stop<'m>(
3526 &mut self,
3527 model: &'m DiscreteModel,
3528 include_constants: bool,
3529 code: Option<&str>,
3530 ) -> Result<FunctionValue<'ctx>> {
3531 let time_dep_fn = self.ensure_time_dep_fn(model, code)?;
3532 let state_dep_fn = self.ensure_state_dep_fn(model, code)?;
3533 let state_dep_post_f_fn = self.ensure_state_dep_post_f_fn(model, code)?;
3534 self.clear();
3535 let fn_arg_names = &["t", "u", "data", "root", "thread_id", "thread_dim"];
3536 let function_name = if include_constants {
3537 "calc_stop_full"
3538 } else {
3539 "calc_stop"
3540 };
3541 let function = self.add_function(
3542 function_name,
3543 fn_arg_names,
3544 &[
3545 self.real_type.into(),
3546 self.real_ptr_type.into(),
3547 self.real_ptr_type.into(),
3548 self.real_ptr_type.into(),
3549 self.int_type.into(),
3550 self.int_type.into(),
3551 ],
3552 None,
3553 false,
3554 );
3555
3556 let alias_id = Attribute::get_named_enum_kind_id("noalias");
3558 let noalign = self.context.create_enum_attribute(alias_id, 0);
3559 for i in &[1, 2, 3] {
3560 function.add_attribute(AttributeLoc::Param(*i), noalign);
3561 }
3562
3563 let _basic_block = self.start_function(function, code);
3564
3565 for (i, arg) in function.get_param_iter().enumerate() {
3566 let name = fn_arg_names[i];
3567 let alloca = self.function_arg_alloca(name, arg);
3568 self.insert_param(name, alloca);
3569 }
3570
3571 self.insert_state(model.state());
3572 self.insert_data(model);
3573 self.insert_indices();
3574
3575 if let Some(stop) = model.stop() {
3576 let mut nbarriers = 0;
3578 let mut total_barriers = (model.time_dep_defns().len()
3579 + model.state_dep_defns().len()
3580 + model.state_dep_post_f_defns().len()
3581 + 1) as u64;
3582 if include_constants {
3583 total_barriers += model.input_dep_defns().len() as u64;
3584 }
3585 let total_barriers_val = self.int_type.const_int(total_barriers, false);
3586 if include_constants {
3587 for tensor in model.input_dep_defns() {
3589 self.jit_compile_tensor(tensor, Some(*self.get_var(tensor)), code)?;
3590 let barrier_num = self.int_type.const_int(nbarriers + 1, false);
3591 self.jit_compile_call_barrier(barrier_num, total_barriers_val);
3592 nbarriers += 1;
3593 }
3594 }
3595 if !model.time_dep_defns().is_empty() {
3596 self.build_dep_call(time_dep_fn, "time_dep", nbarriers, total_barriers)?;
3597 nbarriers += model.time_dep_defns().len() as u64;
3598 }
3599
3600 if !model.state_dep_defns().is_empty() {
3602 self.build_dep_call(state_dep_fn, "state_dep", nbarriers, total_barriers)?;
3603 nbarriers += model.state_dep_defns().len() as u64;
3604 }
3605 if !model.state_dep_post_f_defns().is_empty() {
3606 self.build_dep_call(state_dep_post_f_fn, "state_dep", nbarriers, total_barriers)?;
3607 nbarriers += model.state_dep_post_f_defns().len() as u64;
3608 }
3609
3610 let res_ptr = self.get_param("root");
3611 self.jit_compile_tensor(stop, Some(*res_ptr), code)?;
3612 let barrier_num = self.int_type.const_int(nbarriers + 1, false);
3613 self.jit_compile_call_barrier(barrier_num, total_barriers_val);
3614 }
3615 self.builder.build_return(None)?;
3616
3617 if function.verify(true) {
3618 Ok(function)
3619 } else {
3620 function.print_to_stderr();
3621 unsafe {
3622 function.delete();
3623 }
3624 Err(anyhow!("Invalid generated function."))
3625 }
3626 }
3627
3628 pub fn compile_reset<'m>(
3629 &mut self,
3630 model: &'m DiscreteModel,
3631 include_constants: bool,
3632 code: Option<&str>,
3633 ) -> Result<FunctionValue<'ctx>> {
3634 let time_dep_fn = self.ensure_time_dep_fn(model, code)?;
3635 let state_dep_fn = self.ensure_state_dep_fn(model, code)?;
3636 let state_dep_post_f_fn = self.ensure_state_dep_post_f_fn(model, code)?;
3637 self.clear();
3638 let fn_arg_names = &["t", "u", "data", "reset", "thread_id", "thread_dim"];
3639 let function_name = if include_constants {
3640 "reset_full"
3641 } else {
3642 "reset"
3643 };
3644 let function = self.add_function(
3645 function_name,
3646 fn_arg_names,
3647 &[
3648 self.real_type.into(),
3649 self.real_ptr_type.into(),
3650 self.real_ptr_type.into(),
3651 self.real_ptr_type.into(),
3652 self.int_type.into(),
3653 self.int_type.into(),
3654 ],
3655 None,
3656 false,
3657 );
3658
3659 let alias_id = Attribute::get_named_enum_kind_id("noalias");
3660 let noalign = self.context.create_enum_attribute(alias_id, 0);
3661 for i in &[1, 2, 3] {
3662 function.add_attribute(AttributeLoc::Param(*i), noalign);
3663 }
3664
3665 let _basic_block = self.start_function(function, code);
3666
3667 for (i, arg) in function.get_param_iter().enumerate() {
3668 let name = fn_arg_names[i];
3669 let alloca = self.function_arg_alloca(name, arg);
3670 self.insert_param(name, alloca);
3671 }
3672
3673 self.insert_state(model.state());
3674 self.insert_data(model);
3675 self.insert_indices();
3676
3677 if let Some(reset) = model.reset() {
3678 let mut nbarriers = 0;
3679 let mut total_barriers = (model.time_dep_defns().len()
3680 + model.state_dep_defns().len()
3681 + model.state_dep_post_f_defns().len()
3682 + 1) as u64;
3683 if include_constants {
3684 total_barriers += model.input_dep_defns().len() as u64;
3685 }
3686 let total_barriers_val = self.int_type.const_int(total_barriers, false);
3687 if include_constants {
3688 for tensor in model.input_dep_defns() {
3690 self.jit_compile_tensor(tensor, Some(*self.get_var(tensor)), code)?;
3691 let barrier_num = self.int_type.const_int(nbarriers + 1, false);
3692 self.jit_compile_call_barrier(barrier_num, total_barriers_val);
3693 nbarriers += 1;
3694 }
3695 }
3696 if !model.time_dep_defns().is_empty() {
3697 self.build_dep_call(time_dep_fn, "time_dep", nbarriers, total_barriers)?;
3698 nbarriers += model.time_dep_defns().len() as u64;
3699 }
3700
3701 if !model.state_dep_defns().is_empty() {
3702 self.build_dep_call(state_dep_fn, "state_dep", nbarriers, total_barriers)?;
3703 nbarriers += model.state_dep_defns().len() as u64;
3704 }
3705 if !model.state_dep_post_f_defns().is_empty() {
3706 self.build_dep_call(state_dep_post_f_fn, "state_dep", nbarriers, total_barriers)?;
3707 nbarriers += model.state_dep_post_f_defns().len() as u64;
3708 }
3709
3710 let res_ptr = self.get_param("reset");
3711 self.jit_compile_tensor(reset, Some(*res_ptr), code)?;
3712 let barrier_num = self.int_type.const_int(nbarriers + 1, false);
3713 self.jit_compile_call_barrier(barrier_num, total_barriers_val);
3714 }
3715 self.builder.build_return(None)?;
3716
3717 if function.verify(true) {
3718 Ok(function)
3719 } else {
3720 function.print_to_stderr();
3721 unsafe {
3722 function.delete();
3723 }
3724 Err(anyhow!("Invalid generated function."))
3725 }
3726 }
3727
3728 pub fn compile_rhs<'m>(
3729 &mut self,
3730 model: &'m DiscreteModel,
3731 include_constants: bool,
3732 code: Option<&str>,
3733 ) -> Result<FunctionValue<'ctx>> {
3734 let time_dep_fn = self.ensure_time_dep_fn(model, code)?;
3735 let state_dep_fn = self.ensure_state_dep_fn(model, code)?;
3736 self.clear();
3737 let fn_arg_names = &["t", "u", "data", "rr", "thread_id", "thread_dim"];
3738 let function_name = if include_constants { "rhs_full" } else { "rhs" };
3739 let function = self.add_function(
3740 function_name,
3741 fn_arg_names,
3742 &[
3743 self.real_type.into(),
3744 self.real_ptr_type.into(),
3745 self.real_ptr_type.into(),
3746 self.real_ptr_type.into(),
3747 self.int_type.into(),
3748 self.int_type.into(),
3749 ],
3750 None,
3751 false,
3752 );
3753
3754 let alias_id = Attribute::get_named_enum_kind_id("noalias");
3756 let noalign = self.context.create_enum_attribute(alias_id, 0);
3757 for i in &[1, 2, 3] {
3758 function.add_attribute(AttributeLoc::Param(*i), noalign);
3759 }
3760
3761 let _basic_block = self.start_function(function, code);
3762
3763 for (i, arg) in function.get_param_iter().enumerate() {
3764 let name = fn_arg_names[i];
3765 let alloca = self.function_arg_alloca(name, arg);
3766 self.insert_param(name, alloca);
3767 }
3768
3769 self.insert_state(model.state());
3770 self.insert_data(model);
3771 self.insert_indices();
3772
3773 let mut nbarriers = 0;
3774 let mut total_barriers =
3775 (model.time_dep_defns().len() + model.state_dep_defns().len() + 1) as u64;
3776 if include_constants {
3777 total_barriers += model.input_dep_defns().len() as u64;
3778 for tensor in model.input_dep_defns() {
3780 self.jit_compile_tensor(tensor, Some(*self.get_var(tensor)), code)?;
3781 let barrier_num = self.int_type.const_int(nbarriers + 1, false);
3782 let total_barriers_val = self.int_type.const_int(total_barriers, false);
3783 self.jit_compile_call_barrier(barrier_num, total_barriers_val);
3784 nbarriers += 1;
3785 }
3786 }
3787
3788 if !model.time_dep_defns().is_empty() {
3790 self.build_dep_call(time_dep_fn, "time_dep", nbarriers, total_barriers)?;
3791 nbarriers += model.time_dep_defns().len() as u64;
3792 }
3793
3794 if !model.state_dep_defns().is_empty() {
3795 self.build_dep_call(state_dep_fn, "state_dep", nbarriers, total_barriers)?;
3796 nbarriers += model.state_dep_defns().len() as u64;
3797 }
3798
3799 let res_ptr = self.get_param("rr");
3801 self.jit_compile_tensor(model.rhs(), Some(*res_ptr), code)?;
3802 let total_barriers_val = self.int_type.const_int(total_barriers, false);
3803 let barrier_num = self.int_type.const_int(nbarriers + 1, false);
3804 self.jit_compile_call_barrier(barrier_num, total_barriers_val);
3805
3806 self.builder.build_return(None)?;
3807
3808 if function.verify(true) {
3809 Ok(function)
3810 } else {
3811 function.print_to_stderr();
3812 unsafe {
3813 function.delete();
3814 }
3815 Err(anyhow!("Invalid generated function."))
3816 }
3817 }
3818
3819 pub fn compile_mass<'m>(
3820 &mut self,
3821 model: &'m DiscreteModel,
3822 code: Option<&str>,
3823 ) -> Result<FunctionValue<'ctx>> {
3824 self.clear();
3825 let fn_arg_names = &["t", "dudt", "data", "rr", "thread_id", "thread_dim"];
3826 let function = self.add_function(
3827 "mass",
3828 fn_arg_names,
3829 &[
3830 self.real_type.into(),
3831 self.real_ptr_type.into(),
3832 self.real_ptr_type.into(),
3833 self.real_ptr_type.into(),
3834 self.int_type.into(),
3835 self.int_type.into(),
3836 ],
3837 None,
3838 false,
3839 );
3840
3841 let alias_id = Attribute::get_named_enum_kind_id("noalias");
3843 let noalign = self.context.create_enum_attribute(alias_id, 0);
3844 for i in &[1, 2, 3] {
3845 function.add_attribute(AttributeLoc::Param(*i), noalign);
3846 }
3847
3848 let _basic_block = self.start_function(function, code);
3849
3850 for (i, arg) in function.get_param_iter().enumerate() {
3851 let name = fn_arg_names[i];
3852 let alloca = self.function_arg_alloca(name, arg);
3853 self.insert_param(name, alloca);
3854 }
3855
3856 if model.state_dot().is_some() && model.lhs().is_some() {
3858 let state_dot = model.state_dot().unwrap();
3859 let lhs = model.lhs().unwrap();
3860
3861 self.insert_dot_state(state_dot);
3862 self.insert_data(model);
3863 self.insert_indices();
3864
3865 let mut nbarriers = 0;
3867 let total_barriers =
3868 (model.time_dep_defns().len() + model.dstate_dep_defns().len() + 1) as u64;
3869 let total_barriers_val = self.int_type.const_int(total_barriers, false);
3870 for tensor in model.time_dep_defns() {
3871 self.jit_compile_tensor(tensor, Some(*self.get_var(tensor)), code)?;
3872 let barrier_num = self.int_type.const_int(nbarriers + 1, false);
3873 self.jit_compile_call_barrier(barrier_num, total_barriers_val);
3874 nbarriers += 1;
3875 }
3876
3877 for a in model.dstate_dep_defns() {
3878 self.jit_compile_tensor(a, Some(*self.get_var(a)), code)?;
3879 let barrier_num = self.int_type.const_int(nbarriers + 1, false);
3880 self.jit_compile_call_barrier(barrier_num, total_barriers_val);
3881 nbarriers += 1;
3882 }
3883
3884 let res_ptr = self.get_param("rr");
3886 self.jit_compile_tensor(lhs, Some(*res_ptr), code)?;
3887 let barrier_num = self.int_type.const_int(nbarriers + 1, false);
3888 self.jit_compile_call_barrier(barrier_num, total_barriers_val);
3889 }
3890
3891 self.builder.build_return(None)?;
3892
3893 if function.verify(true) {
3894 Ok(function)
3895 } else {
3896 function.print_to_stderr();
3897 unsafe {
3898 function.delete();
3899 }
3900 Err(anyhow!("Invalid generated function."))
3901 }
3902 }
3903
3904 pub fn compile_gradient(
3905 &mut self,
3906 original_function: FunctionValue<'ctx>,
3907 args_type: &[CompileGradientArgType],
3908 mode: CompileMode,
3909 fn_name: &str,
3910 ) -> Result<FunctionValue<'ctx>> {
3911 self.clear();
3912
3913 let mut fn_type: Vec<BasicMetadataTypeEnum> = Vec::new();
3915
3916 let orig_fn_type_ptr = Self::fn_pointer_type(self.context, original_function.get_type());
3917
3918 let mut enzyme_fn_type: Vec<BasicMetadataTypeEnum> = vec![orig_fn_type_ptr.into()];
3919 let mut start_param_index: Vec<u32> = Vec::new();
3920 let mut ptr_arg_indices: Vec<u32> = Vec::new();
3921 for (i, arg) in original_function.get_param_iter().enumerate() {
3922 let start_index = u32::try_from(fn_type.len()).unwrap();
3923 start_param_index.push(start_index);
3924 let arg_type = arg.get_type();
3925 fn_type.push(arg_type.into());
3926
3927 enzyme_fn_type.push(self.int_type.into());
3929 enzyme_fn_type.push(arg.get_type().into());
3930
3931 if arg_type.is_pointer_type() {
3932 ptr_arg_indices.push(start_index);
3933 }
3934
3935 match args_type[i] {
3936 CompileGradientArgType::Dup | CompileGradientArgType::DupNoNeed => {
3937 fn_type.push(arg.get_type().into());
3938 enzyme_fn_type.push(arg.get_type().into());
3939 if arg_type.is_pointer_type() {
3940 ptr_arg_indices.push(start_index + 1);
3941 }
3942 }
3943 CompileGradientArgType::Const => {}
3944 }
3945 }
3946 let fn_arg_names = fn_type
3947 .iter()
3948 .enumerate()
3949 .map(|(i, _)| format!("arg{}", i))
3950 .collect::<Vec<String>>();
3951 let fn_arg_names_ref = fn_arg_names
3952 .iter()
3953 .map(|s| s.as_str())
3954 .collect::<Vec<&str>>();
3955 let function = self.add_function(fn_name, &fn_arg_names_ref, &fn_type, None, false);
3956
3957 let alias_id = Attribute::get_named_enum_kind_id("noalias");
3959 let noalign = self.context.create_enum_attribute(alias_id, 0);
3960 for i in ptr_arg_indices {
3961 function.add_attribute(AttributeLoc::Param(i), noalign);
3962 }
3963
3964 let _basic_block = self.start_function(function, None);
3965
3966 let mut enzyme_fn_args: Vec<BasicMetadataValueEnum> = Vec::new();
3967 let mut input_activity = Vec::new();
3968 let mut arg_trees = Vec::new();
3969 for (i, arg) in original_function.get_param_iter().enumerate() {
3970 let param_index = start_param_index[i];
3971 let fn_arg = function.get_nth_param(param_index).unwrap();
3972
3973 let concrete_type = match arg.get_type() {
3976 BasicTypeEnum::PointerType(_t) => CConcreteType_DT_Pointer,
3977 BasicTypeEnum::FloatType(_t) => match self.diffsl_real_type {
3978 RealType::F32 => CConcreteType_DT_Float,
3979 RealType::F64 => CConcreteType_DT_Double,
3980 },
3981 BasicTypeEnum::IntType(_) => CConcreteType_DT_Integer,
3982 _ => panic!("unsupported type"),
3983 };
3984 let new_tree = unsafe {
3985 EnzymeNewTypeTreeCT(
3986 concrete_type,
3987 self.context.as_ctx_ref() as *mut LLVMOpaqueContext,
3988 )
3989 };
3990 unsafe { EnzymeTypeTreeOnlyEq(new_tree, -1) };
3991
3992 if concrete_type == CConcreteType_DT_Pointer {
3994 let inner_concrete_type = match self.diffsl_real_type {
3995 RealType::F32 => CConcreteType_DT_Float,
3996 RealType::F64 => CConcreteType_DT_Double,
3997 };
3998 let inner_new_tree = unsafe {
3999 EnzymeNewTypeTreeCT(
4000 inner_concrete_type,
4001 self.context.as_ctx_ref() as *mut LLVMOpaqueContext,
4002 )
4003 };
4004 unsafe { EnzymeTypeTreeOnlyEq(inner_new_tree, -1) };
4005 unsafe { EnzymeTypeTreeOnlyEq(inner_new_tree, -1) };
4006 unsafe { EnzymeMergeTypeTree(new_tree, inner_new_tree) };
4007 }
4008
4009 arg_trees.push(new_tree);
4010 match args_type[i] {
4011 CompileGradientArgType::Dup => {
4012 enzyme_fn_args.push(fn_arg.into());
4014
4015 let fn_darg = function.get_nth_param(param_index + 1).unwrap();
4017 enzyme_fn_args.push(fn_darg.into());
4018
4019 input_activity.push(CDIFFE_TYPE_DFT_DUP_ARG);
4020 }
4021 CompileGradientArgType::DupNoNeed => {
4022 enzyme_fn_args.push(fn_arg.into());
4024
4025 let fn_darg = function.get_nth_param(param_index + 1).unwrap();
4027 enzyme_fn_args.push(fn_darg.into());
4028
4029 input_activity.push(CDIFFE_TYPE_DFT_DUP_NONEED);
4030 }
4031 CompileGradientArgType::Const => {
4032 enzyme_fn_args.push(fn_arg.into());
4034
4035 input_activity.push(CDIFFE_TYPE_DFT_CONSTANT);
4036 }
4037 }
4038 }
4039 let ret_primary_ret = false;
4041 let diff_ret = false;
4042 let ret_activity = CDIFFE_TYPE_DFT_CONSTANT;
4043 let ret_tree = unsafe {
4044 EnzymeNewTypeTreeCT(
4045 CConcreteType_DT_Anything,
4046 self.context.as_ctx_ref() as *mut LLVMOpaqueContext,
4047 )
4048 };
4049
4050 let fnc_opt_base = true;
4052 let logic_ref: EnzymeLogicRef = unsafe { CreateEnzymeLogic(fnc_opt_base as u8) };
4053
4054 let kv_tmp = IntList {
4055 data: std::ptr::null_mut(),
4056 size: 0,
4057 };
4058 let mut known_values = vec![kv_tmp; input_activity.len()];
4059
4060 let fn_type_info = CFnTypeInfo {
4061 Arguments: arg_trees.as_mut_ptr(),
4062 Return: ret_tree,
4063 KnownValues: known_values.as_mut_ptr(),
4064 };
4065
4066 let type_analysis: EnzymeTypeAnalysisRef =
4067 unsafe { CreateTypeAnalysis(logic_ref, std::ptr::null_mut(), std::ptr::null_mut(), 0) };
4068
4069 let mut args_uncacheable = vec![0; arg_trees.len()];
4070
4071 let enzyme_function = match mode {
4072 CompileMode::Forward | CompileMode::ForwardSens => unsafe {
4073 EnzymeCreateForwardDiff(
4074 logic_ref, std::ptr::null_mut(),
4076 std::ptr::null_mut(),
4077 original_function.as_value_ref(),
4078 ret_activity, input_activity.as_mut_ptr(),
4080 input_activity.len(), type_analysis, ret_primary_ret as u8,
4083 CDerivativeMode_DEM_ForwardMode, 1, 0, 0, 1, std::ptr::null_mut(), fn_type_info, 1, args_uncacheable.as_mut_ptr(), args_uncacheable.len(), std::ptr::null_mut(), )
4095 },
4096 CompileMode::Reverse | CompileMode::ReverseSens => {
4097 let mut call_enzyme = || unsafe {
4098 EnzymeCreatePrimalAndGradient(
4099 logic_ref,
4100 std::ptr::null_mut(),
4101 std::ptr::null_mut(),
4102 original_function.as_value_ref(),
4103 ret_activity,
4104 input_activity.as_mut_ptr(),
4105 input_activity.len(),
4106 type_analysis,
4107 ret_primary_ret as u8,
4108 diff_ret as u8,
4109 CDerivativeMode_DEM_ReverseModeCombined,
4110 0,
4111 0, 1,
4113 1,
4114 std::ptr::null_mut(),
4115 0, fn_type_info,
4117 0, args_uncacheable.as_mut_ptr(),
4119 args_uncacheable.len(),
4120 std::ptr::null_mut(),
4121 if self.threaded { 1 } else { 0 }, )
4123 };
4124 if self.threaded {
4125 let _lock = my_mutex.lock().unwrap();
4127 let barrier_string = CString::new("barrier").unwrap();
4128 unsafe {
4129 EnzymeRegisterCallHandler(
4130 barrier_string.as_ptr(),
4131 Some(fwd_handler),
4132 Some(rev_handler),
4133 )
4134 };
4135 let ret = call_enzyme();
4136 unsafe { EnzymeRegisterCallHandler(barrier_string.as_ptr(), None, None) };
4138 ret
4139 } else {
4140 call_enzyme()
4141 }
4142 }
4143 };
4144
4145 unsafe { FreeEnzymeLogic(logic_ref) };
4147 unsafe { FreeTypeAnalysis(type_analysis) };
4148 unsafe { EnzymeFreeTypeTree(ret_tree) };
4149 for tree in arg_trees {
4150 unsafe { EnzymeFreeTypeTree(tree) };
4151 }
4152
4153 let enzyme_function =
4155 unsafe { FunctionValue::new(enzyme_function as LLVMValueRef) }.unwrap();
4156 self.builder
4157 .build_call(enzyme_function, enzyme_fn_args.as_slice(), "enzyme_call")?;
4158
4159 self.builder.build_return(None)?;
4161
4162 if function.verify(true) {
4163 Ok(function)
4164 } else {
4165 function.print_to_stderr();
4166 enzyme_function.print_to_stderr();
4167 unsafe {
4168 function.delete();
4169 }
4170 Err(anyhow!("Invalid generated function."))
4171 }
4172 }
4173
4174 pub fn compile_get_dims(&mut self, model: &DiscreteModel) -> Result<FunctionValue<'ctx>> {
4175 self.clear();
4176 let fn_arg_names = &[
4177 "states",
4178 "inputs",
4179 "outputs",
4180 "data",
4181 "stop",
4182 "has_mass",
4183 "has_reset",
4184 ];
4185 let function = self.add_function(
4186 "get_dims",
4187 fn_arg_names,
4188 &[
4189 self.int_ptr_type.into(),
4190 self.int_ptr_type.into(),
4191 self.int_ptr_type.into(),
4192 self.int_ptr_type.into(),
4193 self.int_ptr_type.into(),
4194 self.int_ptr_type.into(),
4195 self.int_ptr_type.into(),
4196 ],
4197 None,
4198 false,
4199 );
4200 let _block = self.start_function(function, None);
4201
4202 for (i, arg) in function.get_param_iter().enumerate() {
4203 let name = fn_arg_names[i];
4204 let alloca = self.function_arg_alloca(name, arg);
4205 self.insert_param(name, alloca);
4206 }
4207
4208 self.insert_indices();
4209
4210 let number_of_states = model.state().nnz() as u64;
4211 let number_of_inputs = model.input().map(|inp| inp.nnz()).unwrap_or(0) as u64;
4212 let number_of_outputs = match model.out() {
4213 Some(out) => out.nnz() as u64,
4214 None => 0,
4215 };
4216 let number_of_stop = if let Some(stop) = model.stop() {
4217 stop.nnz() as u64
4218 } else {
4219 0
4220 };
4221 let has_mass = match model.lhs().is_some() {
4222 true => 1u64,
4223 false => 0u64,
4224 };
4225 let has_reset = match model.reset().is_some() {
4226 true => 1u64,
4227 false => 0u64,
4228 };
4229 let data_len = self.layout.data().len() as u64;
4230 self.builder.build_store(
4231 *self.get_param("states"),
4232 self.int_type.const_int(number_of_states, false),
4233 )?;
4234 self.builder.build_store(
4235 *self.get_param("inputs"),
4236 self.int_type.const_int(number_of_inputs, false),
4237 )?;
4238 self.builder.build_store(
4239 *self.get_param("outputs"),
4240 self.int_type.const_int(number_of_outputs, false),
4241 )?;
4242 self.builder.build_store(
4243 *self.get_param("data"),
4244 self.int_type.const_int(data_len, false),
4245 )?;
4246 self.builder.build_store(
4247 *self.get_param("stop"),
4248 self.int_type.const_int(number_of_stop, false),
4249 )?;
4250 self.builder.build_store(
4251 *self.get_param("has_mass"),
4252 self.int_type.const_int(has_mass, false),
4253 )?;
4254 self.builder.build_store(
4255 *self.get_param("has_reset"),
4256 self.int_type.const_int(has_reset, false),
4257 )?;
4258 self.builder.build_return(None)?;
4259
4260 if function.verify(true) {
4261 Ok(function)
4262 } else {
4263 function.print_to_stderr();
4264 unsafe {
4265 function.delete();
4266 }
4267 Err(anyhow!("Invalid generated function."))
4268 }
4269 }
4270
4271 pub fn compile_get_tensor(
4272 &mut self,
4273 model: &DiscreteModel,
4274 name: &str,
4275 ) -> Result<FunctionValue<'ctx>> {
4276 self.clear();
4277 let real_ptr_ptr_type = Self::pointer_type(self.context, self.real_ptr_type.into());
4278 let function_name = format!("get_tensor_{name}");
4279 let fn_arg_names = &["data", "tensor_data", "tensor_size"];
4280 let function = self.add_function(
4281 function_name.as_str(),
4282 fn_arg_names,
4283 &[
4284 self.real_ptr_type.into(),
4285 real_ptr_ptr_type.into(),
4286 self.int_ptr_type.into(),
4287 ],
4288 None,
4289 false,
4290 );
4291 let _basic_block = self.start_function(function, None);
4292
4293 for (i, arg) in function.get_param_iter().enumerate() {
4294 let name = fn_arg_names[i];
4295 let alloca = self.function_arg_alloca(name, arg);
4296 self.insert_param(name, alloca);
4297 }
4298
4299 self.insert_data(model);
4300 let ptr = self.get_param(name);
4301 let tensor_size = self.layout.get_layout(name).unwrap().nnz() as u64;
4302 let tensor_size_value = self.int_type.const_int(tensor_size, false);
4303 self.builder
4304 .build_store(*self.get_param("tensor_data"), ptr.as_basic_value_enum())?;
4305 self.builder
4306 .build_store(*self.get_param("tensor_size"), tensor_size_value)?;
4307 self.builder.build_return(None)?;
4308
4309 if function.verify(true) {
4310 Ok(function)
4311 } else {
4312 function.print_to_stderr();
4313 unsafe {
4314 function.delete();
4315 }
4316 Err(anyhow!("Invalid generated function."))
4317 }
4318 }
4319
4320 pub fn compile_get_constant(
4321 &mut self,
4322 model: &DiscreteModel,
4323 name: &str,
4324 ) -> Result<FunctionValue<'ctx>> {
4325 self.clear();
4326 let real_ptr_ptr_type = Self::pointer_type(self.context, self.real_ptr_type.into());
4327 let function_name = format!("get_constant_{name}");
4328 let fn_arg_names = &["tensor_data", "tensor_size"];
4329 let function = self.add_function(
4330 function_name.as_str(),
4331 fn_arg_names,
4332 &[real_ptr_ptr_type.into(), self.int_ptr_type.into()],
4333 None,
4334 false,
4335 );
4336 let _basic_block = self.start_function(function, None);
4337
4338 for (i, arg) in function.get_param_iter().enumerate() {
4339 let name = fn_arg_names[i];
4340 let alloca = self.function_arg_alloca(name, arg);
4341 self.insert_param(name, alloca);
4342 }
4343
4344 self.insert_constants(model);
4345 let ptr = self.get_param(name);
4346 let tensor_size = self.layout.get_layout(name).unwrap().nnz() as u64;
4347 let tensor_size_value = self.int_type.const_int(tensor_size, false);
4348 self.builder
4349 .build_store(*self.get_param("tensor_data"), ptr.as_basic_value_enum())?;
4350 self.builder
4351 .build_store(*self.get_param("tensor_size"), tensor_size_value)?;
4352 self.builder.build_return(None)?;
4353
4354 if function.verify(true) {
4355 Ok(function)
4356 } else {
4357 function.print_to_stderr();
4358 unsafe {
4359 function.delete();
4360 }
4361 Err(anyhow!("Invalid generated function."))
4362 }
4363 }
4364
4365 pub fn compile_inputs(
4366 &mut self,
4367 model: &DiscreteModel,
4368 is_get: bool,
4369 ) -> Result<FunctionValue<'ctx>> {
4370 self.clear();
4371 let function_name = if is_get { "get_inputs" } else { "set_inputs" };
4372 let fn_arg_names: &[&str] = if is_get {
4373 &["inputs", "data"]
4374 } else {
4375 &["inputs", "data", "model_index"]
4376 };
4377 let fn_arg_types: &[BasicMetadataTypeEnum<'ctx>] = if is_get {
4378 &[self.real_ptr_type.into(), self.real_ptr_type.into()]
4379 } else {
4380 &[
4381 self.real_ptr_type.into(),
4382 self.real_ptr_type.into(),
4383 self.int_type.into(),
4384 ]
4385 };
4386 let function = self.add_function(function_name, fn_arg_names, fn_arg_types, None, false);
4387 let block = self.start_function(function, None);
4388
4389 for (i, arg) in function.get_param_iter().enumerate() {
4390 let name = fn_arg_names[i];
4391 let alloca = self.function_arg_alloca(name, arg);
4392 self.insert_param(name, alloca);
4393 }
4394
4395 if !is_get {
4396 let model_index = self
4397 .build_load(self.int_type, *self.get_param("model_index"), "model_index")?
4398 .into_int_value();
4399 self.builder
4400 .build_store(self.globals.model_index.as_pointer_value(), model_index)?;
4401 }
4402
4403 if let Some(input) = model.input() {
4404 let name = input.name();
4405 self.insert_tensor(input, false);
4406 let ptr = self.get_var(input);
4407 let inputs_start_index = self.int_type.const_int(0, false);
4409 let start_index = self.int_type.const_int(0, false);
4410 let end_index = self
4411 .int_type
4412 .const_int(input.nnz().try_into().unwrap(), false);
4413
4414 let input_block = self.context.append_basic_block(function, name);
4415 self.builder.build_unconditional_branch(input_block)?;
4416 self.builder.position_at_end(input_block);
4417 let index = self.builder.build_phi(self.int_type, "i")?;
4418 index.add_incoming(&[(&start_index, block)]);
4419
4420 let curr_input_index = index.as_basic_value().into_int_value();
4422 let input_ptr =
4423 Self::get_ptr_to_index(&self.builder, self.real_type, ptr, curr_input_index, name);
4424 let curr_inputs_index =
4425 self.builder
4426 .build_int_add(inputs_start_index, curr_input_index, name)?;
4427 let inputs_ptr = Self::get_ptr_to_index(
4428 &self.builder,
4429 self.real_type,
4430 self.get_param("inputs"),
4431 curr_inputs_index,
4432 name,
4433 );
4434 if is_get {
4435 let input_value = self
4436 .build_load(self.real_type, input_ptr, name)?
4437 .into_float_value();
4438 self.builder.build_store(inputs_ptr, input_value)?;
4439 } else {
4440 let input_value = self
4441 .build_load(self.real_type, inputs_ptr, name)?
4442 .into_float_value();
4443 self.builder.build_store(input_ptr, input_value)?;
4444 }
4445
4446 let one = self.int_type.const_int(1, false);
4448 let next_index = self.builder.build_int_add(curr_input_index, one, name)?;
4449 index.add_incoming(&[(&next_index, input_block)]);
4450
4451 let loop_while =
4453 self.builder
4454 .build_int_compare(IntPredicate::ULT, next_index, end_index, name)?;
4455 let post_block = self.context.append_basic_block(function, name);
4456 self.builder
4457 .build_conditional_branch(loop_while, input_block, post_block)?;
4458 self.builder.position_at_end(post_block);
4459 }
4460 self.builder.build_return(None)?;
4461
4462 if function.verify(true) {
4463 Ok(function)
4464 } else {
4465 function.print_to_stderr();
4466 unsafe {
4467 function.delete();
4468 }
4469 Err(anyhow!("Invalid generated function."))
4470 }
4471 }
4472
4473 pub fn compile_set_id(&mut self, model: &DiscreteModel) -> Result<FunctionValue<'ctx>> {
4474 self.clear();
4475 let fn_arg_names = &["id"];
4476 let function = self.add_function(
4477 "set_id",
4478 fn_arg_names,
4479 &[self.real_ptr_type.into()],
4480 None,
4481 false,
4482 );
4483 let mut block = self.start_function(function, None);
4484
4485 for (i, arg) in function.get_param_iter().enumerate() {
4486 let name = fn_arg_names[i];
4487 let alloca = self.function_arg_alloca(name, arg);
4488 self.insert_param(name, alloca);
4489 }
4490
4491 let mut id_index = 0usize;
4492 for (blk, is_algebraic) in zip(model.state().elmts(), model.is_algebraic()) {
4493 let name = blk.name().unwrap_or("unknown");
4494 let id_start_index = self.int_type.const_int(id_index as u64, false);
4496 let blk_start_index = self.int_type.const_int(0, false);
4497 let blk_end_index = self
4498 .int_type
4499 .const_int(blk.nnz().try_into().unwrap(), false);
4500
4501 let blk_block = self.context.append_basic_block(function, name);
4502 self.builder.build_unconditional_branch(blk_block)?;
4503 self.builder.position_at_end(blk_block);
4504 let index = self.builder.build_phi(self.int_type, "i")?;
4505 index.add_incoming(&[(&blk_start_index, block)]);
4506
4507 let curr_blk_index = index.as_basic_value().into_int_value();
4509 let curr_id_index = self
4510 .builder
4511 .build_int_add(id_start_index, curr_blk_index, name)?;
4512 let id_ptr = Self::get_ptr_to_index(
4513 &self.builder,
4514 self.real_type,
4515 self.get_param("id"),
4516 curr_id_index,
4517 name,
4518 );
4519 let is_algebraic_float = if *is_algebraic { 0.0 } else { 1.0 };
4520 let is_algebraic_value = self.real_type.const_float(is_algebraic_float);
4521 self.builder.build_store(id_ptr, is_algebraic_value)?;
4522
4523 let one = self.int_type.const_int(1, false);
4525 let next_index = self.builder.build_int_add(curr_blk_index, one, name)?;
4526 index.add_incoming(&[(&next_index, blk_block)]);
4527
4528 let loop_while = self.builder.build_int_compare(
4530 IntPredicate::ULT,
4531 next_index,
4532 blk_end_index,
4533 name,
4534 )?;
4535 let post_block = self.context.append_basic_block(function, name);
4536 self.builder
4537 .build_conditional_branch(loop_while, blk_block, post_block)?;
4538 self.builder.position_at_end(post_block);
4539
4540 block = post_block;
4542 id_index += blk.nnz();
4543 }
4544 self.builder.build_return(None)?;
4545
4546 if function.verify(true) {
4547 Ok(function)
4548 } else {
4549 function.print_to_stderr();
4550 unsafe {
4551 function.delete();
4552 }
4553 Err(anyhow!("Invalid generated function."))
4554 }
4555 }
4556
    /// Returns a shared reference to the LLVM module held in this value's `module` field.
    pub fn module(&self) -> &Module<'ctx> {
        &self.module
    }
4560}