1use std::{collections::HashMap, ffi::{CStr, CString}, sync::RwLockReadGuard, unimplemented};
5
6use anyhow::{anyhow, Result};
7
8use libc::{c_uint, c_ulonglong};
9use llvm_sys::prelude::*;
10use llvm_sys::core::*;
11use llvm_sys::error::*;
12use llvm_sys::{LLVMIntPredicate, LLVMRealPredicate, LLVMTypeKind};
13
14use crate::{codegen::promote_to_vector, parser::{FipsType, CompileTimeConstant, BinaryOperator}, runtime::{MemberData, ParticleIndex, ParticleStore, Domain}};
15use crate::utils::FipsValue;
16
17use super::{NeighborList, evaluate_binop};
18
19pub(crate) unsafe fn llvm_errorref_to_result<T>(context: &str, error: LLVMErrorRef) -> Result<T> {
21 let message = CStr::from_ptr(LLVMGetErrorMessage(error)).to_str()
22 .expect("Error while decoding LLVM error (how ironic)");
23 Err(anyhow!("{}: {}", context, message))
24}
25
/// Turn a string literal into a pointer to a static, NUL-terminated C string.
///
/// The terminator is appended at compile time with `concat!`, so the pointer
/// refers to static data and stays valid for the whole program. The pointee
/// type is inferred from the use site (`*const c_char`, `*const u8`, ...).
#[macro_export]
macro_rules! cstr {
    ($string:expr) => {{
        let terminated: &'static str = concat!($string, "\0");
        terminated.as_ptr() as *const _
    }};
}
33
/// Convert a runtime string (anything `Into<Vec<u8>>`) into a pointer to a
/// NUL-terminated C string.
///
/// The pointer refers to a temporary `CString` that lives only until the end
/// of the enclosing statement. Use the macro directly as a call argument, e.g.
/// `LLVMAddGlobal(module, typ, cstring!(name))`; never bind the result to a
/// variable. The `CString` path is absolute so callers need no `use`.
///
/// # Panics
/// If the string contains an interior NUL byte ("String conversion failed").
#[macro_export]
macro_rules! cstring {
    ($string:expr) => {
        ::std::ffi::CString::new($string)
            .expect("String conversion failed")
            .as_ptr() as *const _
    };
}
41
42pub(crate) unsafe fn get_llvm_type_dims(context: LLVMContextRef, typ: &FipsType) -> Result<(LLVMTypeRef, Vec<usize>)> {
44 Ok(match typ {
45 FipsType::Double => (LLVMDoubleTypeInContext(context), vec![]),
46 FipsType::Int64 => (LLVMInt64TypeInContext(context), vec![]),
47 FipsType::Array { typ, length } => {
48 let length = match length {
49 CompileTimeConstant::Literal(value) | CompileTimeConstant::Substituted(value, _) => value,
50 CompileTimeConstant::Identifier(name) => return Err(anyhow!("Unresolved identifer {}", name))
51 };
52 let (subtype, mut subdims) = get_llvm_type_dims(context, typ)?;
53 subdims.insert(0, *length);
54 (subtype, subdims)
55 }
56 })
57}
58
59pub(crate) unsafe fn create_global_const_double(module: LLVMModuleRef, name: String, value: f64) -> LLVMValueRef {
60 let context = LLVMGetModuleContext(module);
61 let typ = LLVMDoubleTypeInContext(context);
62 let llval = LLVMAddGlobal(module, typ, cstring!(name));
63 let initializer = LLVMConstReal(typ, value);
64 LLVMSetGlobalConstant(llval, 1);
65 LLVMSetInitializer(llval, initializer);
66 llval
67}
68
69pub(crate) unsafe fn create_global_const_int64(module: LLVMModuleRef, name: String, value: i64) -> LLVMValueRef {
70 let context = LLVMGetModuleContext(module);
71 let typ = LLVMInt64TypeInContext(context);
72 let llval = LLVMAddGlobal(module, typ, cstring!(name));
73 let initializer = LLVMConstInt(typ, value as c_ulonglong, 0); LLVMSetGlobalConstant(llval, 1);
75 LLVMSetInitializer(llval, initializer);
76 llval
77}
78
79pub(crate) unsafe fn create_global_const(module: LLVMModuleRef, name: String, value: FipsValue) -> LLVMValueRef {
80 match value {
81 FipsValue::Int64(value) => create_global_const_int64(module, name, value),
82 FipsValue::Double(value) => create_global_const_double(module, name, value)
83 }
84}
85
86unsafe fn __create_global_ptr(context: LLVMContextRef, module: LLVMModuleRef, name: String, scalar_type: LLVMTypeRef,
87 ptr: usize, stride: u32) -> LLVMValueRef
88{
89 let element_type = match stride {
90 1 => scalar_type,
91 n => LLVMArrayType(scalar_type, n as c_uint)
92 };
93 let typ = LLVMPointerType(element_type, 0);
94 let llval = LLVMAddGlobal(module, typ, cstring!(name));
95 let initializer = LLVMConstIntToPtr(LLVMConstInt(LLVMInt64TypeInContext(context), ptr as c_ulonglong, 0), typ);
97 LLVMSetGlobalConstant(llval, 1);
98 LLVMSetInitializer(llval, initializer);
99 llval
100}
101
102pub(crate) unsafe fn create_global_ptr(module: LLVMModuleRef, name: String, typ: &FipsType, ptr: usize) -> Result<LLVMValueRef> {
103 let context = LLVMGetModuleContext(module);
104 let (lltype, dims) = get_llvm_type_dims(context, typ)?;
105 Ok(match dims.len() {
106 0 => __create_global_ptr(context, module, name, lltype, ptr, 1),
107 1 => __create_global_ptr(context, module, name, lltype, ptr, dims[0] as u32),
108 _ => unimplemented!("No multidim support for now"),
109 })
110}
111
112pub(crate) unsafe fn create_local_ptr(module: LLVMModuleRef, builder: LLVMBuilderRef, name: String, typ: &FipsType) -> Result<LLVMValueRef> {
113 let context = LLVMGetModuleContext(module);
114 let (lltyp, dims) = get_llvm_type_dims(context, typ)?;
115 let typ = match dims.len() {
116 0 => lltyp,
117 1 => LLVMArrayType(lltyp, dims[0] as c_uint),
118 _ => unimplemented!("No multidim support for now"),
119 };
120 Ok(LLVMBuildAlloca(builder, typ, cstring!(name)))
121}
122
123pub(crate) unsafe fn fips_value_2_llvm(module: LLVMModuleRef, value: &FipsValue) -> LLVMValueRef {
125 match value {
126 FipsValue::Int64(value) => {
127 let context = LLVMGetModuleContext(module);
128 let typ = LLVMInt64TypeInContext(context);
129 LLVMConstInt(typ, *value as c_ulonglong, 0)
130 },
131 FipsValue::Double(value) => {
132 let context = LLVMGetModuleContext(module);
133 let typ = LLVMDoubleTypeInContext(context);
134 LLVMConstReal(typ, *value)
135 }
136 }
137}
138
139unsafe fn __fips_ptr_2_llvm(context: LLVMContextRef, scalar_type: LLVMTypeRef,
140 ptr: usize, stride: u32) -> LLVMValueRef
141{
142 let element_type = match stride {
143 1 => scalar_type,
144 n => LLVMArrayType(scalar_type, n as c_uint)
145 };
146 let typ = LLVMPointerType(element_type, 0);
147 LLVMConstIntToPtr(LLVMConstInt(LLVMInt64TypeInContext(context), ptr as c_ulonglong, 0), typ)
149}
150
151pub(crate) unsafe fn fips_ptr_2_llvm(module: LLVMModuleRef, typ: &FipsType, ptr: usize) -> Result<LLVMValueRef> {
152 let context = LLVMGetModuleContext(module);
153 let (lltype, dims) = get_llvm_type_dims(context, typ)?;
154 Ok(match dims.len() {
155 0 => __fips_ptr_2_llvm(context, lltype, ptr, 1),
156 1 => __fips_ptr_2_llvm(context, lltype, ptr, dims[0] as u32),
157 _ => unimplemented!("No multidim support for now"),
158 })
159}
160pub(crate) unsafe fn create_neighbor_member_values<'a>(module: LLVMModuleRef, members: Vec<&'a str>,
161 neighbor_list: &RwLockReadGuard<NeighborList>, particle_index: &ParticleIndex, particle_store: &ParticleStore)
162-> HashMap<&'a str, Vec<LLVMValueRef>> {
163 members.iter().map(|member_name| {
164 let llvals = neighbor_list.pos_blocks.iter()
165 .map(move |(particle_id, index_range)| {
166 let particle_def = particle_index.get(*particle_id).unwrap();
167 let particle_data = particle_store.get_particle(*particle_id).unwrap();
168 let (member_id, member_def) = match particle_def.get_member_by_name(member_name) {
169 None => { return None }
170 Some(x) => x
171 };
172 let member_data = particle_data.borrow_member(&member_id).unwrap();
173 Some(match &*member_data {
174 MemberData::Uniform(value) => {
175 fips_value_2_llvm(module, value)
176 }
177 MemberData::PerParticle{data, ..} => {
178 let offset = index_range.start * member_def.get_member_size().unwrap();
180 let data_ptr = data.as_ptr() as usize + offset;
181 fips_ptr_2_llvm(module, member_def.get_type(), data_ptr).unwrap()
182 }
183 })
184 }).collect::<Vec<_>>();
185 let mut lltyp_check = None;
187 for llval in &llvals {
188 if let Some(llval) = llval {
189 lltyp_check = Some(LLVMTypeOf(*llval));
190 break;
191 }
192 }
193 let lltyp_check = lltyp_check.unwrap();
194 let llvals = llvals.into_iter().map(|llval| llval.unwrap()).collect::<Vec<_>>();
198 for llval in &llvals {
199 assert_eq!(lltyp_check, LLVMTypeOf(*llval));
200 }
201 (*member_name, llvals)
202 }).collect::<HashMap<_,_>>()
203}
204
205pub(crate) unsafe fn build_loop(context: LLVMContextRef, builder: LLVMBuilderRef, block_loop_body: LLVMBasicBlockRef,
206 block_after_loop: LLVMBasicBlockRef, loop_index_ptr: LLVMValueRef, end_index: LLVMValueRef) -> LLVMBasicBlockRef {
207 let block_loop_check = LLVMInsertBasicBlockInContext(context, block_after_loop, cstr!("loop_check"));
209 let block_loop_increment = LLVMInsertBasicBlockInContext(context, block_after_loop, cstr!("loop_increment"));
211
212 LLVMBuildBr(builder, block_loop_check);
214 LLVMPositionBuilderAtEnd(builder, block_loop_check);
215 let loop_index = LLVMBuildLoad(builder, loop_index_ptr, cstr!("loop_var_val"));
216 let end_index = match LLVMGetTypeKind(LLVMTypeOf(end_index)) {
217 LLVMTypeKind::LLVMPointerTypeKind => { LLVMBuildLoad(builder, end_index, cstr!("end_index")) }
218 _ => end_index
219 };
220 let comparison = LLVMBuildICmp(builder, LLVMIntPredicate::LLVMIntULT, loop_index, end_index, cstr!("loop_check"));
221 LLVMBuildCondBr(builder, comparison, block_loop_body, block_after_loop);
222
223 LLVMPositionBuilderAtEnd(builder, block_loop_increment);
225 let loop_index = LLVMBuildLoad(builder, loop_index_ptr, cstr!("loop_var_val"));
226 let llone = LLVMConstInt(LLVMInt64TypeInContext(context), 1, 0);
227 let incremented_index = LLVMBuildAdd(builder, loop_index, llone, cstr!("incremented_val"));
228 LLVMBuildStore(builder, incremented_index, loop_index_ptr);
229 LLVMBuildBr(builder, block_loop_check);
230
231 block_loop_increment
232}
233
234pub(crate) unsafe fn llmultiply_by_minus_one(context: LLVMContextRef, builder: LLVMBuilderRef, value: LLVMValueRef) -> LLVMValueRef {
235 match LLVMGetTypeKind(LLVMTypeOf(value)) {
236 LLVMTypeKind::LLVMArrayTypeKind => {
237 let length = LLVMGetArrayLength(LLVMTypeOf(value));
238 llmultiply_by_minus_one(context, builder, promote_to_vector(context, builder, value, length))
239 }
240 LLVMTypeKind::LLVMVectorTypeKind => {
241 let scalar = scalar_minus_one(context, LLVMGetElementType(LLVMTypeOf(value)));
242 let length = LLVMGetVectorSize(LLVMTypeOf(value));
243 let llminus = LLVMConstVector(vec![scalar; length as usize].as_mut_ptr(), length);
244 match LLVMGetTypeKind(LLVMGetElementType(LLVMTypeOf(value))) {
245 LLVMTypeKind::LLVMDoubleTypeKind => {
246 LLVMBuildFMul(builder, llminus, value, cstr!("negated"))
247 }
248 LLVMTypeKind::LLVMIntegerTypeKind => {
249 LLVMBuildMul(builder, llminus, value, cstr!("negated"))
250 }
251 _ => panic!()
252 }
253 }
254 LLVMTypeKind::LLVMDoubleTypeKind => {
255 let llminus = scalar_minus_one(context, LLVMTypeOf(value));
256 LLVMBuildFMul(builder, llminus, value, cstr!("negated"))
257 }
258 LLVMTypeKind::LLVMIntegerTypeKind => {
259 let llminus = scalar_minus_one(context, LLVMTypeOf(value));
260 LLVMBuildMul(builder, llminus, value, cstr!("negated"))
261 }
262 _ => panic!()
263 }
264}
265
266unsafe fn scalar_minus_one(context: LLVMContextRef, typ: LLVMTypeRef) -> LLVMValueRef {
267 let double_type = LLVMDoubleTypeInContext(context);
268 let int64_type = LLVMInt64TypeInContext(context);
269 match LLVMGetTypeKind(typ) {
270 LLVMTypeKind::LLVMDoubleTypeKind => {
271 LLVMConstReal(double_type, -1.0)
272 }
273 LLVMTypeKind::LLVMIntegerTypeKind => {
274 LLVMConstInt(int64_type, std::mem::transmute(-1i64), 1)
275 }
276 _ => panic!()
277 }
278}
279
280pub(crate) unsafe fn calculate_distance_sqr_and_vec(context: LLVMContextRef, builder: LLVMBuilderRef,
281 pos_1: LLVMValueRef, pos_2: LLVMValueRef)
282-> (LLVMValueRef, LLVMValueRef) {
283 let dist_vec = evaluate_binop(context, builder, pos_2, pos_1, BinaryOperator::Sub).unwrap();
285 let dist_vec_sqr = evaluate_binop(context, builder, dist_vec, dist_vec, BinaryOperator::Mul).unwrap();
287 let mut dist = LLVMBuildExtractElement(builder, dist_vec_sqr,
289 LLVMConstInt(LLVMInt64TypeInContext(context), 0, 0), cstr!("dist_acc"));
290 for i in 1..LLVMGetVectorSize(LLVMTypeOf(dist_vec_sqr)) {
291 dist = LLVMBuildFAdd(builder, dist,
292 LLVMBuildExtractElement(builder, dist_vec_sqr,
293 LLVMConstInt(LLVMInt64TypeInContext(context), i as u64, 0), cstr!("dist_elem")),
294 cstr!("dist_acc")
295 );
296 }
297 (dist, dist_vec)
298}
299
300pub(crate) unsafe fn correct_postion_vector(context: LLVMContextRef, builder: LLVMBuilderRef,
301 position: LLVMValueRef, other_position: LLVMValueRef, cutoff_skin: f64, domain: &Domain)
302-> LLVMValueRef {
303 let double_type = LLVMDoubleTypeInContext(context);
304 match domain {
305 Domain::Dim2 { .. } => unimplemented!(),
306 Domain::Dim3 { x, y, z } => {
307 assert!(x.size() > 2.*cutoff_skin);
309 assert!(y.size() > 2.*cutoff_skin);
310 assert!(z.size() > 2.*cutoff_skin);
311 let raw_dist_vec = evaluate_binop(context, builder, position, other_position, BinaryOperator::Sub).unwrap();
314 let cmp1 = LLVMBuildFCmp(builder, LLVMRealPredicate::LLVMRealOGT, raw_dist_vec, LLVMConstVector([
315 LLVMConstReal(double_type, cutoff_skin),
316 LLVMConstReal(double_type, cutoff_skin),
317 LLVMConstReal(double_type, cutoff_skin),
318 ].as_mut_ptr(), 3), cstr!("cmp_gt"));
319 let cmp2 = LLVMBuildFCmp(builder, LLVMRealPredicate::LLVMRealOLT, raw_dist_vec, LLVMConstVector([
320 LLVMConstReal(double_type, -cutoff_skin),
321 LLVMConstReal(double_type, -cutoff_skin),
322 LLVMConstReal(double_type, -cutoff_skin),
323 ].as_mut_ptr(), 3), cstr!("cmp_lt"));
324
325 let cmp1 = LLVMBuildUIToFP(builder, cmp1, LLVMVectorType(double_type, 3), cstr!("cmp_gt_dbl"));
334 let cmp2 = LLVMBuildUIToFP(builder, cmp2, LLVMVectorType(double_type, 3), cstr!(""));
335 let cmp2 = LLVMBuildFMul(builder, cmp2, LLVMConstVector([
336 LLVMConstReal(double_type, -1.),
337 LLVMConstReal(double_type, -1.),
338 LLVMConstReal(double_type, -1.),
339 ].as_mut_ptr(), 3), cstr!("cmp_lt_dbl"));
340 let cmp = LLVMBuildFAdd(builder, cmp1, cmp2, cstr!("cmp"));
342 let corr_vec = LLVMBuildFMul(builder, cmp, LLVMConstVector([
344 LLVMConstReal(double_type, x.size()),
345 LLVMConstReal(double_type, y.size()),
346 LLVMConstReal(double_type, z.size()),
347 ].as_mut_ptr(), 3), cstr!("correction_vec"));
348 evaluate_binop(context, builder, other_position, corr_vec, BinaryOperator::Add).unwrap()
349 }
350 }
351}