1use std::{collections::{HashMap, HashSet}, ffi::{CString, c_void}, mem::MaybeUninit, ptr};
4
5use anyhow::{anyhow, Result};
6
7use libc::{c_int};
8use llvm_sys::{prelude::*, LLVMIntPredicate, LLVMRealPredicate, LLVMLinkage, LLVMTypeKind,
9 analysis::{LLVMVerifierFailureAction, LLVMVerifyModule},
10 initialization::LLVMInitializeCore,
11 orc2::*, orc2::lljit::*,
12 transforms::{pass_manager_builder::*}};
13use llvm_sys::target::*;
14use llvm_sys::core::*;
15use slotmap::{Key, KeyData};
16
17use crate::{codegen::{functions::_random_normal, util::unwrap_usize_constant}, cstr, cstring, parser::{self, AssignStatement, FipsType, Statement}, runtime::{BuiltinFunction, Domain, FunctionID, InteractionQuantityID, MemberData, OutOfBoundsBehavior}};
18
19use super::{
20 util::unwrap_f64_constant,
21 CallbackTarget, ThreadContext,
22 evaluate_expression, evaluate_binop, convert_to_scalar_or_array,
23 analysis::{FipsSymbolKind, SymbolTable, SimulationNode, BarrierKind},
24 llhelpers::*};
25
/// Name of the LLVM module that receives all generated worker code.
const WORKER_MODULE_NAME: &str = "worker_module";
/// Symbol name under which the generated worker entry point is added to the module.
const WORKER_MAIN_NAME: &str = "worker_main";
/// Signature of the generated worker entry point: no arguments, no return value.
type WorkerMainFunc = unsafe extern "C" fn();
29
30#[no_mangle]
31pub unsafe extern "C" fn _call2rust_handler(callback_target: u64, barrier_data: u64) {
32 let callback_target = callback_target as usize as *const CallbackTarget;
33 let barrier = KeyData::from_ffi(barrier_data).into();
34 callback_target.as_ref().unwrap().handle_call2rust(barrier);
35}
36
37#[no_mangle]
38pub unsafe extern "C" fn _interaction_handler(callback_target: u64, barrier_data: u64,
39 block_index: usize, neighbor_list_index_ret: *mut *const usize, neighbor_list_ret: *mut *const usize)
40-> () {
41 let callback_target = callback_target as usize as *const CallbackTarget;
42 let barrier = KeyData::from_ffi(barrier_data).into();
43 let (neighbor_list_index, neighbor_list) = callback_target.as_ref().unwrap().handle_interaction(barrier, block_index);
44 *neighbor_list_index_ret = neighbor_list_index;
45 *neighbor_list_ret = neighbor_list;
46}
47
48#[no_mangle]
49pub unsafe extern "C" fn _interaction_sync_handler(callback_target: u64, barrier_data: u64)
50-> () {
51 let callback_target = callback_target as usize as *const CallbackTarget;
52 let barrier = KeyData::from_ffi(barrier_data).into();
53 callback_target.as_ref().unwrap().handle_interaction_sync(barrier);
54}
55
56#[no_mangle]
57pub unsafe extern "C" fn _end_of_step(callback_target: u64) {
58 let callback_target = callback_target as usize as *const CallbackTarget;
59 callback_target.as_ref().unwrap().end_of_step();
60}
61
/// Debug helper exported to generated code (declared in the module as `print_u64`).
/// It prints `x` in decimal followed by a newline to standard output.
///
/// # Safety
/// The body performs no unsafe operations; callers have no extra obligations.
#[no_mangle]
pub unsafe extern "C" fn print_u64(x: u64) {
    let rendered = x.to_string();
    println!("{}", rendered);
}
66
/// Debug helper exported to generated code (declared in the module as `print_f64`).
/// It prints `x` using its `Display` formatting followed by a newline to standard output.
///
/// # Safety
/// The body performs no unsafe operations; callers have no extra obligations.
#[no_mangle]
pub unsafe extern "C" fn print_f64(x: f64) {
    let rendered = x.to_string();
    println!("{}", rendered);
}
71
/// Host functions that live in this crate and are referenced by generated code.
///
/// `CodeGenerator::new` copies these names into `external_symbols`. Presumably
/// that list is the allow list for resolving symbols from the host process
/// (confirm against the JIT setup).
const FIPS_FUNCS: &[&str] = &[
    "_call2rust_handler",
    "_interaction_handler",
    "_interaction_sync_handler",
    "_end_of_step",
    "print_u64",
    "print_f64",
    "_random_normal",
];
82
/// System/libc functions that generated code may reference.
///
/// `CodeGenerator::new` appends these names to `external_symbols` after
/// `FIPS_FUNCS`.
const SYSTEM_FUNCS: &[&str] = &[
    "memset",
    "fmod",
    "sqrt",
    "sin",
    "cos",
    "sincos",
];
91
92pub extern "C" fn allowed_symbol_filter(ctx: *mut c_void, sym: LLVMOrcSymbolStringPoolEntryRef) -> c_int {
93 unsafe {
94 if ctx.is_null() {
95 panic!("Cannot call allowed_symbol_filter with a null context");
96 }
97
98 let allow_list: *mut LLVMOrcSymbolStringPoolEntryRef = std::mem::transmute_copy(&ctx);
99
100 let mut allowed_symbol = allow_list;
102 while !(*allowed_symbol).is_null() {
103 if sym == *allowed_symbol {
104 return 1;
105 }
106 allowed_symbol = allowed_symbol.offset(1);
107 }
108
109 return 0;
111 }
112}
113
114pub(crate) enum LLSymbolValue {
116 SimplePointer(LLVMValueRef),
119 ParticleMember {
123 base_ptr: LLVMValueRef,
124 local_ptr: Option<LLVMValueRef>
125 },
126 Function(LLFunctionSymbolValue)
129}
130
131pub(crate) struct LLFunctionSymbolValue {
132 pub(crate) function_id: FunctionID,
134 pub(crate) function: LLVMValueRef,
136 pub(crate) global_parameter_ptrs: Vec<Option<LLVMValueRef>>
138}
139
140struct InteractionValues {
142 own_pos_block_index: LLVMValueRef,
143 interaction_func: LLVMValueRef
144}
145
146pub struct CodeGenerator {
148 module_ts: LLVMOrcThreadSafeModuleRef,
150 callback_target: Box<CallbackTarget>,
152 external_symbols: Vec<String>
154}
155
156impl CodeGenerator {
157 pub(crate) fn new(thread_context: ThreadContext) -> Result<Self> {
158 let callback_target = Box::new(CallbackTarget::new(thread_context));
160 let callback_target_ptr = &*callback_target as *const CallbackTarget as u64; let thread_context = &callback_target.thread_context;
162 let particle_id = thread_context.particle_id;
164 let domain = &thread_context.executor_context.global_context.runtime.domain;
165 let timeline = thread_context.executor_context.global_context.simgraph.timelines.get(&particle_id).unwrap();
166 let particle_index = &thread_context.executor_context.global_context.runtime.particle_index;
167 let particle_store = &thread_context.executor_context.global_context.runtime.particle_store;
168 let particle_data = particle_store.get_particle(particle_id).unwrap();
169 let particle_range = thread_context.particle_range.clone();
170 let function_index = &thread_context.executor_context.global_context.runtime.function_index;
171 let global_symbols: SymbolTable<LLSymbolValue> = thread_context
173 .executor_context.global_context.global_symbols.clone().convert();
174 let particle_symbols: SymbolTable<LLSymbolValue> = timeline.particle_symbols.clone().convert();
175 let mut symbol_table = SymbolTable::new();
176 symbol_table.push_table(global_symbols);
177 symbol_table.push_table(particle_symbols);
178 let neighbor_lists = thread_context.executor_context.neighbor_lists.iter()
180 .map(|(interaction_id, neighbor_list)| {
181 let neighbor_list = neighbor_list.read().unwrap();
182 (*interaction_id, neighbor_list)
183 })
184 .collect::<HashMap<_,_>>();
185
186 unsafe {
187 let context_ts = LLVMOrcCreateNewThreadSafeContext();
189 let context = LLVMOrcThreadSafeContextGetContext(context_ts);
190 let module = LLVMModuleCreateWithNameInContext(cstring!(WORKER_MODULE_NAME), context);
191 let builder = LLVMCreateBuilderInContext(context);
192
193 let void_type = LLVMVoidTypeInContext(context);
195 let int64_type = LLVMInt64TypeInContext(context);
196 let int8_type = LLVMInt8TypeInContext(context);
197 let double_type = LLVMDoubleTypeInContext(context);
198
199 let start_index = LLVMConstInt(int64_type, particle_range.start as u64, 0);
201 let end_index = LLVMConstInt(int64_type, particle_range.end as u64, 0);
202
203 let barrier_handler_type = LLVMFunctionType(void_type, [int64_type, int64_type].as_mut_ptr(), 2, 0);
209 let call2rust_handler = LLVMAddFunction(module, cstr!("_call2rust_handler"), barrier_handler_type);
210 let interaction_handler_type = LLVMFunctionType(void_type, [int64_type, int64_type, int64_type,
211 LLVMPointerType(LLVMPointerType(int64_type, 0), 0),
212 LLVMPointerType(LLVMPointerType(int64_type, 0), 0)].as_mut_ptr(), 5, 0);
213 let interaction_handler = LLVMAddFunction(module, cstr!("_interaction_handler"), interaction_handler_type);
214 let interaction_sync_handler = LLVMAddFunction(module, cstr!("_interaction_sync_handler"), barrier_handler_type);
215 let end_of_step_handler_type = LLVMFunctionType(void_type, [int64_type].as_mut_ptr(), 1, 0);
216 let end_of_step_handler = LLVMAddFunction(module, cstr!("_end_of_step"), end_of_step_handler_type);
217
218 let callback_target_ptrptr = LLVMAddGlobal(module, int64_type, cstr!("_callback_target_ptr"));
220 let initializer = LLVMConstInt(int64_type, callback_target_ptr, 0);
221 LLVMSetGlobalConstant(callback_target_ptrptr, 1);
222 LLVMSetInitializer(callback_target_ptrptr, initializer);
223
224 let print_func_u64_type = LLVMFunctionType(void_type, [int64_type].as_mut_ptr(), 1, 0);
226 #[allow(unused_variables)]
227 let print_func_u64 = LLVMAddFunction(module, cstr!("print_u64"), print_func_u64_type);
228 let print_func_f64_type = LLVMFunctionType(void_type, [double_type].as_mut_ptr(), 1, 0);
229 #[allow(unused_variables)]
230 let print_func_f64 = LLVMAddFunction(module, cstr!("print_f64"), print_func_f64_type);
231
232 let mut external_symbols = FIPS_FUNCS.iter()
234 .chain(SYSTEM_FUNCS.iter())
235 .map(|symbol_name| symbol_name.to_string())
236 .collect::<Vec<_>>();
237 for (_, function_def) in function_index.get_functions() {
238 external_symbols.push(function_def.get_name().to_string());
239 }
240
241 for (name, symbol) in symbol_table.iter_mut() {
243 match &symbol.kind {
244 FipsSymbolKind::Constant(const_val) => {
246 let llname = format!("constant_{}", name);
247 let llval = create_global_const(module, llname, const_val.clone());
248 symbol.set_value(LLSymbolValue::SimplePointer(llval));
249 }
250 FipsSymbolKind::ParticleMember(member_id) => {
256 let member_definition = particle_index.get(particle_id).unwrap()
258 .get_member(&member_id).unwrap();
259 let member_data = particle_data.borrow_member(member_id).unwrap();
260 match &*member_data {
261 MemberData::Uniform(value) => {
262 let llname = format!("uniform_{}", name);
263 let llval = create_global_const(module, llname, value.clone());
264 symbol.set_value(LLSymbolValue::SimplePointer(llval));
265 },
266 MemberData::PerParticle{data, ..} => {
267 let llname = format!("base_addr_{}", name);
268 let llval = create_global_ptr(module, llname, member_definition.get_type(),
269 data.as_ptr() as usize)?;
270 symbol.set_value(LLSymbolValue::ParticleMember{
271 base_ptr: llval,
272 local_ptr: None
273 });
274 }
275 }
276 }
277 FipsSymbolKind::Function(function_id) => {
278 let val = function_index.get(*function_id).unwrap().create_symbol_value(*function_id, context, module)?;
279 symbol.set_value(val);
280 }
281 _ => panic!("Faulty symbol table: global symbols must be either constants or particle members")
282 }
283 }
284
285 let sqrt_func = match &symbol_table.resolve_symbol(BuiltinFunction::Sqrt.get_name())
287 .unwrap().value.as_ref().unwrap()
288 {
289 LLSymbolValue::Function(LLFunctionSymbolValue { function, .. }) => *function,
290 _ => panic!("Corrupted sqrt function."),
291 };
292
293 let interaction_values = neighbor_lists.iter()
295 .filter_map(|(interaction_id, neighbor_list)| {
296 let interaction = thread_context.executor_context.global_context.runtime.interaction_index.get(*interaction_id).unwrap();
298 let interaction_name = interaction.get_name();
299
300 let target_names_a = interaction.iter().map(|(_, quantity_def)| quantity_def.get_target_a())
301 .collect::<Vec<_>>();
302 let target_names_b = interaction.iter().map(|(_, quantity_def)| quantity_def.get_target_b())
303 .collect::<Vec<_>>();
304 let (type_a, type_a_def) = particle_index.get_particle_by_name(interaction.get_type_a()).unwrap();
305 let (type_b, type_b_def) = particle_index.get_particle_by_name(interaction.get_type_b()).unwrap();
306 let is_a = particle_id == type_a; if particle_id != type_a && particle_id != type_b {
310 return None;
311 }
312
313 let namespace_a = interaction.get_name_a();
315 let namespace_b = interaction.get_name_b();
316 let members_a = type_a_def.get_members().map(|(_, member_def)| member_def.get_name())
318 .filter(|member_name| !target_names_a.contains(member_name)) .collect::<Vec<_>>();
320 let members_b = type_b_def.get_members().map(|(_, member_def)| member_def.get_name())
321 .filter(|member_name| !target_names_b.contains(member_name))
322 .collect::<Vec<_>>();
323 let quantity_members_a = type_a_def.get_members().map(|(_, member_def)| member_def.get_name())
325 .filter(|member_name| target_names_a.contains(member_name)) .collect::<Vec<_>>();
327 let quantity_members_b = type_b_def.get_members().map(|(_, member_def)| member_def.get_name())
328 .filter(|member_name| target_names_b.contains(member_name))
329 .collect::<Vec<_>>();
330 let position_member_a_name = type_a_def.get_position_member().unwrap().1.get_name();
332 let position_member_b_name = type_b_def.get_position_member().unwrap().1.get_name();
333 let member_vals_a = create_neighbor_member_values(module, members_a, neighbor_list, particle_index, particle_store);
335 let member_vals_b = create_neighbor_member_values(module, members_b, neighbor_list, particle_index, particle_store);
336 let quantity_member_vals_a = create_neighbor_member_values(module, quantity_members_a, neighbor_list, particle_index, particle_store);
337 let quantity_member_vals_b = create_neighbor_member_values(module, quantity_members_b, neighbor_list, particle_index, particle_store);
338 let ((own_namespace, own_position_name, own_member_vals, own_quantity_member_vals),
340 (other_namespace, other_position_name, other_member_vals)) =
341 if is_a {
342 ((namespace_a, position_member_a_name, member_vals_a, quantity_member_vals_a),
343 (namespace_b, position_member_b_name, member_vals_b))
344 }
345 else {
346 ((namespace_b, position_member_b_name, member_vals_b, quantity_member_vals_b),
347 (namespace_a, position_member_a_name, member_vals_a))
348 };
349 let vals_to_global_array = |mut llvals: Vec<LLVMValueRef>, name| {
351 let llelem_type = LLVMTypeOf(llvals[0]);
352 let llval = LLVMConstArray(llelem_type, llvals.as_mut_ptr(), llvals.len() as u32);
353 let llarray_type = LLVMTypeOf(llval);
354 let llglobal = LLVMAddGlobal(module, llarray_type, cstring!(name));
355 LLVMSetInitializer(llglobal, llval);
356 LLVMSetGlobalConstant(llglobal, 1);
357 llglobal
358 };
359 let own_members = own_member_vals.into_iter()
360 .map(|(member_name, llvals)| {
361 let name = format!("neigh_{}_own_{}", interaction_name, member_name);
362 (member_name.to_string(), vals_to_global_array(llvals, name))
363 }).collect::<HashMap<_,_>>();
364 let own_quantity_members = own_quantity_member_vals.into_iter()
365 .map(|(member_name, llvals)| {
366 let name = format!("neigh_{}_own_quantity_{}", interaction_name, member_name);
367 (member_name.to_string(), vals_to_global_array(llvals, name))
368 }).collect::<HashMap<_,_>>();
369 let other_members = other_member_vals.into_iter()
370 .map(|(member_name, llvals)| {
371 let name = format!("neigh_{}_other_{}", interaction_name, member_name);
372 (member_name.to_string(), vals_to_global_array(llvals, name))
373 }).collect::<HashMap<_,_>>();
374 let cutoff = unwrap_f64_constant(&interaction.get_cutoff()).unwrap();
377 let cutoff_sqr = cutoff*cutoff;
378 let cutoff_sqr = LLVMConstReal(double_type, cutoff_sqr);
379 let block_size_max = neighbor_list.pos_block_size;
381 let block_size_max = LLVMConstInt(int64_type, block_size_max as u64, 0);
382 let mut own_block_index = None;
384 let mut own_block_length = None;
385 for (i, (block_particle_id, block_particle_range)) in neighbor_list.pos_blocks.iter().enumerate() {
386 if particle_id == *block_particle_id && particle_range == *block_particle_range {
387 own_block_index = Some(i);
388 own_block_length = Some(block_particle_range.len());
389 }
390 }
391 let own_block_index = own_block_index.unwrap(); let own_block_length = own_block_length.unwrap();
393 let _own_block_index = own_block_index; let _own_block_length = own_block_length;
395 let own_block_index = LLVMConstInt(int64_type, own_block_index as u64, 0);
396 let own_block_length = LLVMConstInt(int64_type, own_block_length as u64, 0);
397
398 let name = format!("interaction_{}_func", interaction_name);
416 let interaction_func_type = LLVMFunctionType(void_type,
417 [LLVMPointerType(int64_type, 0), LLVMPointerType(int64_type, 0)].as_mut_ptr(), 2, 0);
418 let interaction_func = LLVMAddFunction(module, cstring!(name), interaction_func_type);
419 LLVMSetLinkage(interaction_func, LLVMLinkage::LLVMLinkerPrivateLinkage);
420 let interaction_func_entry = LLVMAppendBasicBlockInContext(context, interaction_func, cstr!("entry"));
421 LLVMPositionBuilderAtEnd(builder, interaction_func_entry);
422 let neighbor_list_index = LLVMGetParam(interaction_func, 0);
424 let neighbor_list = LLVMGetParam(interaction_func, 1);
425 let outer_index_ptr = LLVMBuildAlloca(builder, int64_type, cstr!("i_ptr"));
427 let inner_index_ptr = LLVMBuildAlloca(builder, int64_type, cstr!("n_ptr"));
428 let current_offset_ptr = LLVMBuildAlloca(builder, int64_type, cstr!("a_ptr"));
429 let next_offset_ptr = LLVMBuildAlloca(builder, int64_type, cstr!("b_ptr"));
430 let other_block_index_ptr = LLVMBuildAlloca(builder, int64_type, cstr!("other_block_index_ptr"));
431 let other_offset_ptr = LLVMBuildAlloca(builder, int64_type, cstr!("other_offset_ptr"));
432 let distance_sqr_ptr = LLVMBuildAlloca(builder, double_type, cstr!("dist_sqr_ptr"));
433 let alloca_members = |(member_name, llglobal), prefix: &str| {
434 let lltype = LLVMGetElementType(LLVMGetElementType(LLVMTypeOf(llglobal)));
435 let lltype = match LLVMGetTypeKind(lltype) {
436 LLVMTypeKind::LLVMPointerTypeKind => LLVMGetElementType(lltype),
437 _ => lltype
438 };
439 let llval = LLVMBuildAlloca(builder, lltype, cstring!(format!("{}_{}_ptr", prefix, member_name)));
440 (member_name, (llglobal, llval))
441 };
442 let extract_local_symbols = |statement_block: &Vec<Statement>| {
443 let mut local_symbols = SymbolTable::new();
444 for statement in statement_block {
445 match statement {
446 Statement::Let(let_stmt) => {
447 let llval = LLSymbolValue::SimplePointer(create_local_ptr(module, builder,
448 let_stmt.name.clone(), &let_stmt.typ).unwrap());
449 local_symbols.add_local_symbol_with_value(let_stmt.name.clone(),
450 let_stmt.typ.clone(), llval).unwrap();
451 },
452 _ => {}
453 }
454 }
455 local_symbols
456 };
457 let own_members = own_members.into_iter()
458 .map(|x| alloca_members(x, "own")).collect::<HashMap<_,_>>();
459 let other_members = other_members.into_iter()
460 .map(|x| alloca_members(x, "other")).collect::<HashMap<_,_>>();
461 let own_quantity_members = own_quantity_members.into_iter()
462 .map(|x| alloca_members(x, "own_quantity")).collect::<HashMap<_,_>>();
463 let common_local_symbols = match interaction.get_common_block() {
464 Some(statement_block) => extract_local_symbols(statement_block),
465 None => SymbolTable::new(),
466 };
467 let mut local_symbols = interaction.iter()
468 .map(|(quantity_id, quantity_def)| {
469 match quantity_def.get_expression() {
470 parser::Expression::Block(block) => {
471 (quantity_id, extract_local_symbols(&block.statements))
472 }
473 _ => (quantity_id, SymbolTable::new())
474 }
475 }).collect::<HashMap<InteractionQuantityID, SymbolTable<_>>>();
476 let mut extra_symbols = SymbolTable::new();
477 let mut distance_ptr = None;
478 if let parser::Identifier::Named(distance_name) = interaction.get_distance_identifier() {
479 let llval = LLVMBuildAlloca(builder, double_type, cstr!("distance"));
480 distance_ptr = Some(llval);
481 extra_symbols.add_local_symbol_with_value(distance_name.clone(), FipsType::Double,
482 LLSymbolValue::SimplePointer(llval)).unwrap();
483 }
484 let mut distance_vec_ptr = None;
485 if let Some(distance_vec_name) = interaction.get_distance_vec() {
486 let lltyp = LLVMVectorType(double_type, domain.get_dim() as u32);
487 let llval = LLVMBuildAlloca(builder, lltyp, cstr!("distance"));
488 distance_vec_ptr = Some(llval);
489 extra_symbols.add_local_symbol_with_value(distance_vec_name.clone(),
490 FipsType::Array{ typ: Box::new(FipsType::Double), length: parser::CompileTimeConstant::Literal(domain.get_dim())},
491 LLSymbolValue::SimplePointer(llval)).unwrap();
492 }
493 LLVMBuildStore(builder, LLVMConstInt(int64_type, 0, 0), outer_index_ptr);
495 LLVMBuildStore(builder, LLVMConstInt(int64_type, 0, 0), current_offset_ptr);
496 for (member_name, (llglobal, _)) in own_quantity_members.iter() {
498 let loaded_global = LLVMBuildLoad(builder, *llglobal, cstr!(""));
500 let llptr = LLVMBuildExtractValue(builder, loaded_global, _own_block_index as u32,
502 cstring!(format!("own_quantity_{}_block_ptr", member_name)));
503 assert!(matches!(LLVMGetTypeKind(LLVMTypeOf(llptr)), LLVMTypeKind::LLVMPointerTypeKind));
505 let elem_size = LLVMSizeOf(LLVMGetElementType(LLVMTypeOf(llptr)));
507 let mem_size = LLVMBuildMul(builder, elem_size, own_block_length, cstr!("mem_size"));
508 let llptr = LLVMBuildBitCast(builder, llptr, LLVMPointerType(int8_type, 0), cstr!(""));
511 LLVMBuildMemSet(builder, llptr, LLVMConstInt(int8_type, 0, 0), mem_size, 1);
512 }
513
514 let block_exit = LLVMAppendBasicBlockInContext(context, interaction_func, cstr!("exit"));
516 let block_outer_loop_body = LLVMAppendBasicBlockInContext(context, interaction_func, cstr!("outer_loop_body"));
517 let block_inner_loop_body = LLVMAppendBasicBlockInContext(context, interaction_func, cstr!("inner_loop_body"));
518 let block_outer_loop_body_exit = LLVMAppendBasicBlockInContext(context, interaction_func, cstr!("outer_loop_body_exit"));
519 let block_outer_loop_increment = build_loop(context, builder, block_outer_loop_body,
520 block_exit, outer_index_ptr, own_block_length);
521
522 LLVMPositionBuilderAtEnd(builder, block_outer_loop_body_exit);
524 for (member_name, (llglobal, lllocal)) in own_quantity_members.iter() {
526 let loaded_global = LLVMBuildLoad(builder, *llglobal, cstr!(""));
528 let llptr = LLVMBuildExtractValue(builder, loaded_global, _own_block_index as u32,
530 cstring!(format!("own_quantity_{}_block_ptr", member_name)));
531
532 let outer_index = LLVMBuildLoad(builder, outer_index_ptr, cstr!("i"));
534 let llptr_writeback = LLVMBuildGEP(builder, llptr, [outer_index].as_mut_ptr(), 1,
535 cstring!(format!("writeback_ptr_{}", member_name)));
536
537 let llglobalval = LLVMBuildLoad(builder, llptr_writeback, cstring!(format!("global_val_{}", member_name)));
539 let llacc = LLVMBuildLoad(builder, *lllocal, cstring!(format!("acc_{}", member_name)));
540 let llval = convert_to_scalar_or_array(context, builder,
541 evaluate_binop(context, builder, llglobalval, llacc, parser::BinaryOperator::Add).unwrap());
542
543 LLVMBuildStore(builder, llval, llptr_writeback);
545 }
546 let next_offset = LLVMBuildLoad(builder, next_offset_ptr, cstr!("next_offset"));
547 LLVMBuildStore(builder, next_offset, current_offset_ptr);
548 LLVMBuildBr(builder, block_outer_loop_increment);
549
550 LLVMPositionBuilderAtEnd(builder, block_outer_loop_body);
551 let outer_index = LLVMBuildLoad(builder, outer_index_ptr, cstr!("i"));
552 for (_, (_, lllocal)) in own_quantity_members.iter() {
554 let elem_size = LLVMSizeOf(LLVMGetElementType(LLVMTypeOf(*lllocal)));
556 let llptr = LLVMBuildBitCast(builder, *lllocal, LLVMPointerType(int8_type, 0), cstr!(""));
559 LLVMBuildMemSet(builder, llptr, LLVMConstInt(int8_type, 0, 0), elem_size, 1);
560 }
561 let load_members = |(member_name, (llglobal, lllocal)): (String, (LLVMValueRef, LLVMValueRef)),
563 block_index: LLVMValueRef, particle_index: LLVMValueRef, infix: &str|
564 {
565 let loaded_global = LLVMBuildLoad(builder, llglobal, cstr!(""));
567 let ptrtype = LLVMPointerType(LLVMGetElementType(LLVMTypeOf(loaded_global)), 0);
569 let global_ptr = LLVMBuildBitCast(builder, llglobal, ptrtype, cstr!("global_ptr"));
570 let llptr = LLVMBuildGEP(builder, global_ptr, [block_index].as_mut_ptr(), 1, cstr!(""));
572 let llval = LLVMBuildLoad(builder, llptr, cstring!(format!("loaded_{}_{}", infix, member_name)));
573 let llval = match LLVMGetTypeKind(LLVMTypeOf(llval)) {
575 LLVMTypeKind::LLVMPointerTypeKind => {
576 let llptr = LLVMBuildGEP(builder, llval, [particle_index].as_mut_ptr(), 1, cstr!(""));
577 LLVMBuildLoad(builder, llptr, cstring!(format!("really_loaded_{}_{}", infix, member_name)))
578 },
579 _ => llval
580 };
581 LLVMBuildStore(builder, llval, lllocal);
582 (member_name, lllocal)
583 };
584 let own_members_loaded = own_members.into_iter()
585 .map(|x| load_members(x, own_block_index, outer_index, "own")).collect::<HashMap<_,_>>();
586 let llptr = LLVMBuildGEP(builder, neighbor_list_index, [outer_index].as_mut_ptr(), 1, cstr!(""));
588 let next_offset = LLVMBuildLoad(builder, llptr, cstr!("next_offset"));
589 LLVMBuildStore(builder, next_offset, next_offset_ptr);
590 let current_offset = LLVMBuildLoad(builder, current_offset_ptr, cstr!("previous_offset"));
592 LLVMBuildStore(builder, current_offset, inner_index_ptr);
593 let block_inner_loop_increment = build_loop(context, builder, block_inner_loop_body,
594 block_outer_loop_body_exit, inner_index_ptr, next_offset_ptr);
595 LLVMPositionBuilderAtEnd(builder, block_inner_loop_body);
597 let inner_index = LLVMBuildLoad(builder, inner_index_ptr, cstr!("n"));
598 let llptr = LLVMBuildGEP(builder, neighbor_list, [inner_index].as_mut_ptr(), 1, cstr!("j_ptr"));
600 let other_particle_index = LLVMBuildLoad(builder, llptr, cstr!("j"));
601 let other_block_index = LLVMBuildUDiv(builder, other_particle_index, block_size_max, cstr!("other_block"));
602 LLVMBuildStore(builder, other_block_index, other_block_index_ptr);
603 let other_offset = LLVMBuildURem(builder, other_particle_index, block_size_max, cstr!("other_offset"));
604 LLVMBuildStore(builder, other_offset, other_offset_ptr);
605 let other_members_loaded = other_members.into_iter()
607 .map(|x| load_members(x, other_block_index, other_offset, "other")).collect::<HashMap<_,_>>();
608 let own_position = LLVMBuildLoad(builder, *own_members_loaded.get(own_position_name).unwrap(), cstr!("own_position"));
610 let other_position = LLVMBuildLoad(builder, *other_members_loaded.get(other_position_name).unwrap(), cstr!("other_position"));
611 let cutoff_skin = cutoff * thread_context.executor_context
613 .global_context.runtime.enabled_interactions.get(interaction_id).unwrap().skin_factor;
614 let other_position = correct_postion_vector(context, builder, own_position, other_position,
615 cutoff_skin, domain);
616 let (distance_sqr, distance_vec) = if is_a {
618 calculate_distance_sqr_and_vec(context, builder, own_position, other_position)
619 }
620 else {
621 calculate_distance_sqr_and_vec(context, builder, other_position, own_position)
622 };
623 LLVMBuildStore(builder, distance_sqr, distance_sqr_ptr);
624 if let Some(distance_vec_ptr) = distance_vec_ptr {
625 LLVMBuildStore(builder, distance_vec, distance_vec_ptr);
626 }
627 let within_cutoff = LLVMBuildFCmp(builder, LLVMRealPredicate::LLVMRealOLT,
629 distance_sqr, cutoff_sqr, cstr!("within_cutoff"));
630 let block_process_interaction = LLVMInsertBasicBlockInContext(context, block_inner_loop_increment, cstr!("process_interaction"));
631 LLVMBuildCondBr(builder, within_cutoff, block_process_interaction, block_inner_loop_increment);
632 LLVMPositionBuilderAtEnd(builder, block_process_interaction);
634
635 let particle_symbols = symbol_table.pop_table().unwrap();
637 if let Some(distance_ptr) = distance_ptr {
639 let distance_sqr = LLVMBuildLoad(builder, distance_sqr_ptr, cstr!("dist_sqr"));
640 let distance = LLVMBuildCall(builder, sqrt_func, [distance_sqr].as_mut_ptr(), 1, cstr!("dist"));
641 LLVMBuildStore(builder, distance, distance_ptr);
642 }
643 symbol_table.push_table(extra_symbols);
644 symbol_table.push_table(common_local_symbols);
645 let mut namespace_symbols = HashMap::new();
647 if let parser::Identifier::Named(own_namespace) = own_namespace {
648 namespace_symbols.insert(own_namespace, own_members_loaded);
649 }
650 if let parser::Identifier::Named(other_namespace) = other_namespace {
651 namespace_symbols.insert(other_namespace, other_members_loaded);
652 }
653 let handle_statement = |statement: &Statement, symbol_table: &mut SymbolTable<_>| {
655 match statement {
656 Statement::Let(_) | Statement::Assign(_) => {
657 let (target_name, expression) = match statement{
658 Statement::Let(statement) => (&statement.name, &statement.initial),
659 Statement::Assign(statement) => {
660 if let Some(_) = statement.index {
661 unimplemented!("Indexed assignment in quantities not supported yet");
662 }
663 (&statement.assignee, &statement.value)
664 },
665 _ => unreachable!()
666 };
667 let target = symbol_table.resolve_symbol(target_name)
668 .expect(&format!("Unresolved identifier {}", target_name));
669 let target = match target.kind {
670 FipsSymbolKind::Constant(_) => { panic!("Cannot assign to constant.") }
671 FipsSymbolKind::Function(_) => { panic!("Cannot assign to function.") }
672 FipsSymbolKind::ParticleMember(_) => { panic!("Cannot assign to particle member during interaction.") }
673 FipsSymbolKind::LocalVariable(_) => {
674 match target.value.as_ref().unwrap() {
677 LLSymbolValue::SimplePointer(ptr) => { *ptr }
678 LLSymbolValue::ParticleMember { .. } | LLSymbolValue::Function { .. } => {
679 panic!("Malformed symbol table in interaction!")
680 }
681 }
682 }
683 };
684 let value = evaluate_expression(context, builder, expression,
685 &symbol_table, &namespace_symbols, function_index, callback_target_ptrptr)
686 .expect("Cannot evaluate interaction expression");
687 LLVMBuildStore(builder, convert_to_scalar_or_array(context, builder, value), target);
688 }
689 Statement::Update(_) => { panic!("Update statements are not allowed in interactions!") }
690 Statement::Call(_) => { panic!("Call statements are not allowed in interactions!") }
691 }
692 };
693
694 if let Some(statement_block) = interaction.get_common_block() {
696 for statement in statement_block {
697 handle_statement(statement, &mut symbol_table);
698 }
699 }
700 for (quantity_id, quantity_def) in interaction.iter() {
702 symbol_table.push_table(local_symbols.remove(&quantity_id).unwrap());
703 match quantity_def.get_expression() {
704 parser::Expression::Block(block) => {
705 for statement in &block.statements {
706 handle_statement(statement, &mut symbol_table);
707 }
708 let quantity_value = evaluate_expression(context, builder, &block.expression,
710 &symbol_table, &namespace_symbols, function_index, callback_target_ptrptr)
711 .expect("Cannot evaluate interaction expression");
712
713 let value = match quantity_def.get_symmetry() {
715 parser::InteractionSymmetry::Symmetric => { quantity_value },
717 parser::InteractionSymmetry::Antisymmetric => {
719 if !is_a {
720 llmultiply_by_minus_one(context, builder, quantity_value)
721 }
722 else {
723 quantity_value
724 }
725 },
726 parser::InteractionSymmetry::Asymmetric => {
728 let quantity_value = convert_to_scalar_or_array(context, builder, quantity_value);
729 let lltyp = LLVMTypeOf(quantity_value);
731 assert!(matches!(LLVMGetTypeKind(lltyp), LLVMTypeKind::LLVMArrayTypeKind));
732 assert_eq!(LLVMGetArrayLength(lltyp), 2);
733 let idx = if is_a {0} else {1};
735 LLVMBuildExtractValue(builder, quantity_value, idx, cstr!("quantity_part_own"))
736 },
737 };
738 assert!(matches!(quantity_def.get_reduction_method(), parser::ReductionMethod::Sum));
741 let target_name = if is_a { quantity_def.get_target_a() } else { quantity_def.get_target_b() };
742 let (llglobal, lllocal) = own_quantity_members.get(target_name).unwrap();
743 let accval = LLVMBuildLoad(builder, *lllocal, cstring!(format!("acc_{}", quantity_def.get_name())));
745 let writeback_value = convert_to_scalar_or_array(context, builder,
746 evaluate_binop(context, builder, accval, value, parser::BinaryOperator::Add).unwrap());
747 if LLVMTypeOf(writeback_value) != LLVMGetElementType(LLVMTypeOf(*lllocal)) {
748 panic!("Mismatched type in return expression for interaction quantity {}", quantity_def.get_name());
749 }
750 LLVMBuildStore(builder, writeback_value, *lllocal);
751
752 let block_particle_store = LLVMInsertBasicBlockInContext(context, block_inner_loop_increment,
755 cstring!(format!("partner_store_{}", quantity_def.get_name())));
756 let block_next_quantity = LLVMInsertBasicBlockInContext(context, block_inner_loop_increment,
757 cstring!(format!("quantity_after_{}", quantity_def.get_name())));
758 let other_block_index = LLVMBuildLoad(builder, other_block_index_ptr, cstr!("other_block_index"));
759 let is_same_block = LLVMBuildICmp(builder, LLVMIntPredicate::LLVMIntEQ,
760 own_block_index, other_block_index, cstr!("is_same_block"));
761 LLVMBuildCondBr(builder, is_same_block, block_particle_store, block_next_quantity);
762 LLVMPositionBuilderAtEnd(builder, block_particle_store);
764
765 let value = match quantity_def.get_symmetry() {
766 parser::InteractionSymmetry::Symmetric => { quantity_value },
768 parser::InteractionSymmetry::Antisymmetric => {
770 if is_a {
771 llmultiply_by_minus_one(context, builder, quantity_value)
772 }
773 else {
774 quantity_value
775 }
776 },
777 parser::InteractionSymmetry::Asymmetric => {
779 let quantity_value = convert_to_scalar_or_array(context, builder, quantity_value);
780 let idx = if is_a {1} else {0};
783 LLVMBuildExtractValue(builder, quantity_value, idx, cstr!("quantity_part_own"))
784 }
785 };
786 let loaded_global = LLVMBuildLoad(builder, *llglobal, cstr!(""));
788 let llptr = LLVMBuildExtractValue(builder, loaded_global, _own_block_index as u32,
790 cstring!(format!("own_quantity_{}_block_ptr", target_name)));
791 let other_offset = LLVMBuildLoad(builder, other_offset_ptr, cstr!("other_offset"));
793 let llptr_writeback = LLVMBuildGEP(builder, llptr, [other_offset].as_mut_ptr(), 1,
794 cstring!(format!("writeback_ptr_{}", target_name)));
795 let llglobalacc = LLVMBuildLoad(builder, llptr_writeback, cstring!(format!("global_val_{}", target_name)));
797 let llval = convert_to_scalar_or_array(context, builder,
798 evaluate_binop(context, builder, llglobalacc, value, parser::BinaryOperator::Add).unwrap());
799 LLVMBuildStore(builder, llval, llptr_writeback);
800
801 LLVMBuildBr(builder, block_next_quantity);
802 LLVMPositionBuilderAtEnd(builder, block_next_quantity);
803
804 }
837 _ => todo!()
838 }
839 symbol_table.pop_table(); }
841 symbol_table.pop_table(); symbol_table.pop_table(); LLVMBuildBr(builder, block_inner_loop_increment);
846
847 symbol_table.push_table(particle_symbols);
849
850 LLVMPositionBuilderAtEnd(builder, block_exit);
851 LLVMBuildRetVoid(builder);
852
853 let code = InteractionValues {
854 own_pos_block_index: own_block_index,
855 interaction_func
856 };
857 Some((interaction_id, code))
858 })
859 .collect::<HashMap<_,_>>();
860
861 let worker_main_type = LLVMFunctionType(void_type, std::ptr::null_mut(), 0, 0);
863 let main_function = LLVMAddFunction(module, cstring!(WORKER_MAIN_NAME), worker_main_type);
864 let main_entry = LLVMAppendBasicBlockInContext(context, main_function, cstr!("entry"));
865
866 LLVMPositionBuilderAtEnd(builder, main_entry);
868
869 let neighbor_list_index_var = LLVMBuildAlloca(builder,
871 LLVMPointerType(int64_type, 0), cstr!("neighbor_list_index_var"));
872 let neighbor_list_var = LLVMBuildAlloca(builder,
873 LLVMPointerType(int64_type, 0), cstr!("neighbor_list_var"));
874
875 for node in &timeline.nodes {
876 match node {
877 SimulationNode::StatementBlock(node) => {
878 let node_func = LLVMAddFunction(module, cstr!("node_func"), worker_main_type);
880 LLVMSetLinkage(node_func, LLVMLinkage::LLVMLinkerPrivateLinkage);
881 let block_node_main_entry = LLVMAppendBasicBlockInContext(context, node_func, cstr!("entry"));
882
883 LLVMPositionBuilderAtEnd(builder, block_node_main_entry);
885
886 let mut local_symbols: SymbolTable<LLSymbolValue> = node.local_symbols.clone().convert();
888 for (name, symbol) in local_symbols.iter_mut() {
889 match &symbol.kind {
890 FipsSymbolKind::LocalVariable(typ) => {
891 match typ {
892 FipsType::Double => {
893 let typ = &typ.clone();
894 symbol.set_value(LLSymbolValue::SimplePointer(
895 create_local_ptr(module, builder, name.clone(), typ)?
896 ));
897 }
898 _ => todo!() }
900 },
901 _ => panic!("Faulty symbol table: found non-local-variable symbol in local symbols")
902 }
903 }
904 symbol_table.push_table(local_symbols);
906
907 for (name, symbol) in symbol_table.iter_mut() {
909 match &symbol.kind {
910 FipsSymbolKind::ParticleMember(member_id) => {
911 let name = format!("current_{}", name);
912 let member_definition = particle_index.get(particle_id).unwrap()
913 .get_member(&member_id).unwrap();
914 match symbol.value.as_mut().unwrap() {
915 LLSymbolValue::ParticleMember { local_ptr, .. } => {
916 *local_ptr = Some(create_local_ptr(module, builder, name, &member_definition.get_type())?)
917 }
918 LLSymbolValue::SimplePointer(_) | LLSymbolValue::Function { .. } => {}
920 };
921 }
922 FipsSymbolKind::Constant(_) => {}
923 FipsSymbolKind::LocalVariable(_) => {}
924 FipsSymbolKind::Function(_) => {},
925 }
926 }
927
928 let loop_index_ptr = LLVMBuildAlloca(builder, int64_type, cstr!("loop_var"));
930
931 LLVMBuildStore(builder, start_index, loop_index_ptr);
935
936 let block_loop_check = LLVMAppendBasicBlockInContext(context, node_func, cstr!("loop_check"));
938 let block_loop_body = LLVMAppendBasicBlockInContext(context, node_func, cstr!("loop_body"));
939 let block_loop_increment = LLVMAppendBasicBlockInContext(context, node_func, cstr!("loop_increment"));
940 let block_after_loop = LLVMAppendBasicBlockInContext(context, node_func, cstr!("after_loop"));
941
942 LLVMBuildBr(builder, block_loop_check);
944 LLVMPositionBuilderAtEnd(builder, block_loop_check);
945 let loop_index = LLVMBuildLoad(builder, loop_index_ptr, cstr!("loop_var_val"));
946 let comparison = LLVMBuildICmp(builder, LLVMIntPredicate::LLVMIntULT, loop_index, end_index, cstr!("loop_check"));
947 LLVMBuildCondBr(builder, comparison, block_loop_body, block_after_loop);
948
949 LLVMPositionBuilderAtEnd(builder, block_loop_increment);
951 let loop_index = LLVMBuildLoad(builder, loop_index_ptr, cstr!("loop_var_val"));
952 let llone = LLVMConstInt(int64_type, 1, 0);
953 let incremented_index = LLVMBuildAdd(builder, loop_index, llone, cstr!("incremented_val"));
954 LLVMBuildStore(builder, incremented_index, loop_index_ptr);
955 LLVMBuildBr(builder, block_loop_check);
956
957 LLVMPositionBuilderAtEnd(builder, block_loop_body);
959
960 let loop_index = LLVMBuildLoad(builder, loop_index_ptr, cstr!("loop_var_val"));
962 for (name, symbol) in symbol_table.iter() {
963 match &symbol.kind {
964 FipsSymbolKind::ParticleMember(_) => {
965 match symbol.value.as_ref().unwrap() {
966 LLSymbolValue::ParticleMember { base_ptr, local_ptr, .. } => {
967 let llname = format!("base_addr_loaded_{}", &name);
968 let base_ptr = LLVMBuildLoad(builder, *base_ptr, cstring!(llname));
969 let llname = format!("current_ptr_{}", &name);
970 let current_ptr = LLVMBuildGEP(builder, base_ptr, [loop_index].as_mut_ptr(), 1, cstring!(llname));
971
972 let llname = format!("loaded_{}", &name);
973 let llval = LLVMBuildLoad(builder, current_ptr, cstring!(llname));
974 LLVMBuildStore(builder, llval, local_ptr.unwrap());
975 }
976 LLSymbolValue::SimplePointer(_) | LLSymbolValue::Function { .. } => {}
978 };
979 }
980 FipsSymbolKind::Constant(_) => {}
981 FipsSymbolKind::LocalVariable(_) => {}
982 FipsSymbolKind::Function(_) => {},
983 }
984 }
985
986 let loop_index = LLVMBuildLoad(builder, loop_index_ptr, cstr!("loop_var_val"));
987
988 let mut members_changed = HashSet::new(); for statement in &node.statements {
991 match statement {
992 Statement::Let(_) | Statement::Assign(_) => {
994 let (target_name, expression) = match statement{
995 Statement::Let(statement) => (&statement.name, &statement.initial),
996 Statement::Assign(statement) => (&statement.assignee, &statement.value),
997 _ => unreachable!()
998 };
999 let target = symbol_table.resolve_symbol(target_name)
1000 .ok_or(anyhow!("Unresolved identifier {}", target_name))?;
1001 let target = match target.kind {
1002 FipsSymbolKind::Constant(_) => { return Err(anyhow!("Cannot assign to constant.")) }
1003 FipsSymbolKind::Function(_) => { return Err(anyhow!("Cannot assign to function.")) }
1004 FipsSymbolKind::LocalVariable(_) | FipsSymbolKind::ParticleMember(_) => {
1005 match target.kind {
1007 FipsSymbolKind::ParticleMember(member_id) => {
1008 members_changed.insert(member_id);
1009 }
1010 _ => {}
1011 }
1012 match target.value.as_ref().unwrap() {
1015 LLSymbolValue::SimplePointer(ptr) => { *ptr }
1016 LLSymbolValue::ParticleMember { local_ptr, .. } => {
1017 local_ptr.unwrap()
1018 }
1019 LLSymbolValue::Function { .. } => panic!("Target of let or assign statement cannot be a function!")
1020 }
1021 }
1022 };
1023 let value = evaluate_expression(context, builder, expression,
1024 &symbol_table, &HashMap::new(), function_index, callback_target_ptrptr)?;
1025 match statement {
1026 Statement::Let(_) | Statement::Assign(AssignStatement {index: None, ..})
1028 => {
1029 LLVMBuildStore(builder, convert_to_scalar_or_array(context, builder, value), target);
1030 },
1031 Statement::Assign(AssignStatement {index: Some(index), ..}) => {
1033 let lltyp = LLVMGetElementType(LLVMTypeOf(target));
1036 match LLVMGetTypeKind(lltyp) {
1037 LLVMTypeKind::LLVMArrayTypeKind => {
1038 let elemtyp = LLVMGetElementType(lltyp);
1039 match LLVMGetTypeKind(elemtyp) {
1040 LLVMTypeKind::LLVMDoubleTypeKind | LLVMTypeKind::LLVMIntegerTypeKind => {}
1041 _ => unimplemented!("Multidimensional assignment not supported")
1042 }
1043 },
1044 _ => panic!("Trying to index non-array type (identifier {})!", target_name)
1045 }
1046
1047 let name = format!("old_{}_assign", target_name);
1048 let llval = LLVMBuildLoad(builder, target, cstring!(name));
1049 let name = format!("new_{}_assign", target_name);
1050 let llval = LLVMBuildInsertValue(builder, llval, value,
1051 unwrap_usize_constant(index)? as u32, cstring!(name));
1052 LLVMBuildStore(builder, llval, target);
1053 },
1054 _ => unreachable!(),
1055 }
1056
1057 }
1058 Statement::Update(_) | Statement::Call(_)
1059 => panic!("Update and call statements not eliminated in simgraph construction")
1060 }
1061 }
1062
1063 for (name, symbol) in symbol_table.iter() {
1065 match &symbol.kind {
1066 FipsSymbolKind::ParticleMember(member_id) => {
1067 if members_changed.contains(member_id) {
1070 match symbol.value.as_ref().unwrap() {
1071 LLSymbolValue::ParticleMember { base_ptr, local_ptr, .. } => {
1072 let llname = format!("base_addr_loaded_{}", &name);
1073 let base_ptr = LLVMBuildLoad(builder, *base_ptr, cstring!(llname));
1074 let llname = format!("current_ptr_{}", &name);
1075 let current_ptr = LLVMBuildGEP(builder, base_ptr, [loop_index].as_mut_ptr(), 1, cstring!(llname));
1076 let llname = format!("final_{}", &name);
1077 let mut llval = LLVMBuildLoad(builder, local_ptr.unwrap(), cstring!(llname));
1078 if particle_index.get(particle_id).unwrap()
1081 .get_member(member_id).unwrap()
1082 .is_position()
1083 {
1084 match domain {
1085 Domain::Dim2{x,y} => {
1086 assert!(matches!(x.oob, OutOfBoundsBehavior::Periodic));
1087 assert!(matches!(y.oob, OutOfBoundsBehavior::Periodic));
1088 }
1089 Domain::Dim3{x,y,z} => {
1090 assert!(matches!(x.oob, OutOfBoundsBehavior::Periodic));
1091 assert!(matches!(y.oob, OutOfBoundsBehavior::Periodic));
1092 assert!(matches!(z.oob, OutOfBoundsBehavior::Periodic));
1093 }
1094 }
1095 let lldomain_lo = match domain {
1096 Domain::Dim2{x,y} => {
1097 LLVMConstVector([
1098 LLVMConstReal(double_type, x.low),
1099 LLVMConstReal(double_type, y.low),
1100 ].as_mut_ptr(), 2)
1101 }
1102 Domain::Dim3{x,y,z} => {
1103 LLVMConstVector([
1104 LLVMConstReal(double_type, x.low),
1105 LLVMConstReal(double_type, y.low),
1106 LLVMConstReal(double_type, z.low),
1107 ].as_mut_ptr(), 3)
1108 }
1109 };
1110 let lldomain_size = match domain {
1111 Domain::Dim2{x,y} => {
1112 LLVMConstVector([
1113 LLVMConstReal(double_type, x.size()),
1114 LLVMConstReal(double_type, y.size()),
1115 ].as_mut_ptr(), 2)
1116 }
1117 Domain::Dim3{x,y,z} => {
1118 LLVMConstVector([
1119 LLVMConstReal(double_type, x.size()),
1120 LLVMConstReal(double_type, y.size()),
1121 LLVMConstReal(double_type, z.size()),
1122 ].as_mut_ptr(), 3)
1123 }
1124 };
1125 llval = evaluate_binop(context, builder, llval,
1126 lldomain_lo, parser::BinaryOperator::Sub)?;
1127 llval = LLVMBuildFRem(builder, llval, lldomain_size, cstr!(""));
1128 llval = LLVMBuildFAdd(builder, llval, lldomain_size, cstr!(""));
1129 llval = LLVMBuildFRem(builder, llval, lldomain_size, cstr!(""));
1130 llval = LLVMBuildFAdd(builder, llval, lldomain_lo, cstr!(""));
1131 llval = convert_to_scalar_or_array(context, builder, llval);
1132 }
1133
1134 LLVMBuildStore(builder, llval, current_ptr);
1135 }
1136 LLSymbolValue::SimplePointer(_) => {}
1138 LLSymbolValue::Function { .. } => panic!("Particle member symbol has function value")
1139 };
1140 }
1141 }
1142 FipsSymbolKind::Constant(_) => {}
1143 FipsSymbolKind::LocalVariable(_) => {}
1144 FipsSymbolKind::Function(_) => {},
1145 }
1146 }
1147
1148 LLVMBuildBr(builder, block_loop_increment);
1149
1150 LLVMPositionBuilderAtEnd(builder, block_after_loop);
1152 LLVMBuildRetVoid(builder);
1153
1154 symbol_table.pop_table();
1156
1157 LLVMPositionBuilderAtEnd(builder, main_entry);
1158 LLVMBuildCall(builder, node_func, std::ptr::null_mut(), 0, cstr!(""));
1159 }
1160 SimulationNode::CommonBarrier(barrier_id) => {
1161 match &thread_context.executor_context.global_context.simgraph.barriers.get(*barrier_id).unwrap().kind {
1163 BarrierKind::CallBarrier(_) => {
1165 let barrier_data = barrier_id.data().as_ffi();
1166 let llbarrier_data = LLVMConstInt(int64_type, barrier_data, 0);
1167 let callback_target_param = LLVMBuildLoad(builder, callback_target_ptrptr, cstr!("tmp"));
1168 LLVMBuildCall(builder, call2rust_handler, [callback_target_param, llbarrier_data].as_mut_ptr(), 2, cstr!(""));
1169 }
1170 BarrierKind::InteractionBarrier(interaction_id, quantity_id) => {
1171 if quantity_id.is_some() {
1172 unimplemented!();
1173 }
1174 let interaction_vals = interaction_values.get(interaction_id); if let Some(interaction_vals) = interaction_vals {
1176 let barrier_data = barrier_id.data().as_ffi();
1177 let llbarrier_data = LLVMConstInt(int64_type, barrier_data, 0);
1178 let callback_target_param = LLVMBuildLoad(builder, callback_target_ptrptr, cstr!("tmp"));
1179 let block_index = interaction_vals.own_pos_block_index;
1180 LLVMBuildCall(builder, interaction_handler, [callback_target_param, llbarrier_data, block_index,
1182 neighbor_list_index_var, neighbor_list_var].as_mut_ptr(), 5, cstr!(""));
1183 let neighbor_list_index = LLVMBuildLoad(builder, neighbor_list_index_var, cstr!("neighbor_list_index"));
1184 let neighbor_list = LLVMBuildLoad(builder, neighbor_list_var, cstr!("neighbor_list"));
1185 LLVMBuildCall(builder, interaction_vals.interaction_func,
1186 [neighbor_list_index, neighbor_list].as_mut_ptr(), 2, cstr!(""));
1187 LLVMBuildCall(builder, interaction_sync_handler, [callback_target_param,
1189 llbarrier_data].as_mut_ptr(), 2, cstr!(""));
1190 }
1191 else {
1192 let name = thread_context.executor_context.global_context.runtime.interaction_index.get(*interaction_id)
1193 .unwrap().get_name();
1194 println!("Debug: Ignoring interaction barrier for disabled interaction {}", name);
1195 }
1196 }
1197
1198 }
1199
1200 }
1201 }
1202 }
1203
1204 LLVMPositionBuilderAtEnd(builder, main_entry);
1206 let callback_target_param = LLVMBuildLoad(builder, callback_target_ptrptr, cstr!("tmp"));
1207 LLVMBuildCall(builder, end_of_step_handler, [callback_target_param].as_mut_ptr(), 1, cstr!(""));
1208 LLVMBuildRetVoid(builder);
1209
1210 LLVMVerifyModule(module, LLVMVerifierFailureAction::LLVMPrintMessageAction, ptr::null_mut());
1212
1213 let pm_builder = LLVMPassManagerBuilderCreate();
1215 LLVMPassManagerBuilderUseInlinerWithThreshold(pm_builder, 255);
1216 LLVMPassManagerBuilderSetOptLevel(pm_builder, 2);
1217 let module_pass_manager = LLVMCreatePassManager();
1218 LLVMPassManagerBuilderPopulateModulePassManager(pm_builder, module_pass_manager);
1219
1220 LLVMRunPassManager(module_pass_manager, module);
1221
1222 #[cfg(debug_assertions)] {
1223 if thread_context.particle_range.start == 0 {
1224 let module_cstr = LLVMPrintModuleToString(module);
1225 let module_str = std::ffi::CStr::from_ptr(module_cstr).to_str()?;
1226 println!("{}", module_str);
1227 LLVMDisposeMessage(module_cstr);
1228 }
1229 }
1230 LLVMVerifyModule(module, LLVMVerifierFailureAction::LLVMPrintMessageAction, ptr::null_mut());
1231
1232 let module_ts = LLVMOrcCreateNewThreadSafeModule(module, context_ts);
1234 LLVMOrcDisposeThreadSafeContext(context_ts);
1235 std::mem::drop(particle_data);
1239 std::mem::drop(neighbor_lists);
1241 Ok(Self {
1242 module_ts,
1243 callback_target,
1244 external_symbols
1245 })
1246 }
1247 }
1248}
1249
1250pub struct CodeExecutor {
1252 jit: LLVMOrcLLJITRef,
1254 #[allow(dead_code)]
1256 callback_target: Box<CallbackTarget>,
1257 #[allow(dead_code)]
1259 allowed_syms: Box<[LLVMOrcSymbolStringPoolEntryRef]>,
1260 #[allow(dead_code)]
1262 dummy_func_vec: Vec<*const extern "C" fn()>
1263}
1264
1265impl CodeExecutor {
1266 pub(crate) fn new(codegen: CodeGenerator) -> Result<Self> {
1267 unsafe {
1268 LLVMInitializeCore(LLVMGetGlobalPassRegistry());
1270 LLVM_InitializeNativeTarget();
1271 LLVM_InitializeNativeAsmPrinter();
1272 let dummy_func_vec = vec![
1274 _call2rust_handler as _,
1275 _interaction_handler as _,
1276 _interaction_sync_handler as _,
1277 _end_of_step as _,
1278 print_u64 as _,
1279 print_f64 as _,
1280 _random_normal as _
1281 ];
1282 let mut jit = MaybeUninit::uninit();
1286 let error = LLVMOrcCreateLLJIT(jit.as_mut_ptr(), ptr::null_mut());
1287 if !error.is_null() {
1288 return llvm_errorref_to_result("Failed to create LLJIT", error);
1289 };
1290 let jit = jit.assume_init();
1291
1292 let mut allowed_syms = codegen.external_symbols.iter()
1294 .map(|symbol_name| LLVMOrcLLJITMangleAndIntern(jit, cstring!(symbol_name.clone())))
1295 .chain(std::iter::once(ptr::null_mut()))
1296 .collect::<Box<[_]>>();
1297
1298 let mut process_symbols_generator = MaybeUninit::uninit();
1299 LLVMOrcCreateDynamicLibrarySearchGeneratorForProcess(
1300 process_symbols_generator.as_mut_ptr(), LLVMOrcLLJITGetGlobalPrefix(jit),
1301 Some(allowed_symbol_filter), allowed_syms.as_mut_ptr() as *mut c_void
1302 );
1303 let process_symbols_generator = process_symbols_generator.assume_init();
1304 LLVMOrcJITDylibAddGenerator(LLVMOrcLLJITGetMainJITDylib(jit), process_symbols_generator);
1305
1306 let main_jitdylib = LLVMOrcLLJITGetMainJITDylib(jit);
1308 let error = LLVMOrcLLJITAddLLVMIRModule(jit, main_jitdylib, codegen.module_ts);
1309 if !error.is_null() {
1310 return llvm_errorref_to_result("Failed to add IR module", error);
1311 };
1312
1313 Ok(Self {
1315 jit,
1316 allowed_syms,
1317 dummy_func_vec,
1318 callback_target: codegen.callback_target
1319 })
1320 }
1321 }
1322
1323 pub(crate) fn run(&mut self) {
1324 unsafe {
1325 let mut funcaddr = MaybeUninit::uninit();
1326 let error = LLVMOrcLLJITLookup(self.jit, funcaddr.as_mut_ptr(), cstring!(WORKER_MAIN_NAME));
1327 if !error.is_null() {
1328 panic!("Lookup of worker main function failed!");
1329 };
1330 let func: WorkerMainFunc = std::mem::transmute_copy(&funcaddr);
1331 func();
1332 }
1333 }
1334}