fips_md/codegen/
llhelpers.rs

1//! Macros and related things for making llvm-sys a little less painful
2
3
4use std::{collections::HashMap, ffi::{CStr, CString}, sync::RwLockReadGuard, unimplemented};
5
6use anyhow::{anyhow, Result};
7
8use libc::{c_uint, c_ulonglong};
9use llvm_sys::prelude::*;
10use llvm_sys::core::*;
11use llvm_sys::error::*;
12use llvm_sys::{LLVMIntPredicate, LLVMRealPredicate, LLVMTypeKind};
13
14use crate::{codegen::promote_to_vector, parser::{FipsType, CompileTimeConstant, BinaryOperator}, runtime::{MemberData, ParticleIndex, ParticleStore, Domain}};
15use crate::utils::FipsValue;
16
17use super::{NeighborList, evaluate_binop};
18
19// Convert LLVMErrorRef to Result
20pub(crate) unsafe fn llvm_errorref_to_result<T>(context: &str, error: LLVMErrorRef) -> Result<T> {
21    let message = CStr::from_ptr(LLVMGetErrorMessage(error)).to_str()
22            .expect("Error while decoding LLVM error (how ironic)");
23    Err(anyhow!("{}: {}", context, message))
24}
25
/// Turn a string literal into a raw pointer to a NUL-terminated copy of it.
///
/// The terminator is appended at compile time through `concat!`, so the argument has to be
/// something `concat!` accepts (a literal). The pointee is static data; the pointer type is
/// inferred from the call site.
#[macro_export]
macro_rules! cstr {
    ($literal:expr) => {{
        const NUL_TERMINATED: &str = concat!($literal, "\0");
        NUL_TERMINATED.as_ptr() as *const _
    }};
}
33
/// Convert a runtime string (anything `CString::new` accepts) into a raw pointer to a
/// NUL-terminated C string.
///
/// The `CString` created here is a temporary. It is dropped at the end of the statement that
/// contains the macro invocation, so the pointer is only valid when it is passed directly as
/// an argument to a call that uses or copies the string before returning. Never store the
/// result in a variable.
///
/// # Panics
/// Panics with "String conversion failed" if the input contains an interior NUL byte.
#[macro_export]
macro_rules! cstring {
    ($string:expr) => {
        // Fully qualified path: call sites do not need `CString` in scope.
        ::std::ffi::CString::new($string).expect("String conversion failed")
            // Take the pointer from the slice that includes the terminator the callee reads.
            .as_bytes_with_nul().as_ptr() as *const _
    }
}
41
42/// Get (type, dimensions) tuple for LLVM type corresponding to FIPS type
43pub(crate) unsafe fn get_llvm_type_dims(context: LLVMContextRef, typ: &FipsType) -> Result<(LLVMTypeRef, Vec<usize>)> {
44    Ok(match typ {
45        FipsType::Double => (LLVMDoubleTypeInContext(context), vec![]),
46        FipsType::Int64 => (LLVMInt64TypeInContext(context), vec![]),
47        FipsType::Array { typ, length } => {
48            let length = match length {
49                CompileTimeConstant::Literal(value) | CompileTimeConstant::Substituted(value, _) => value,
50                CompileTimeConstant::Identifier(name) => return Err(anyhow!("Unresolved identifer {}", name))
51            };
52            let (subtype, mut subdims) = get_llvm_type_dims(context, typ)?;
53            subdims.insert(0, *length);
54            (subtype, subdims)
55        } 
56    })
57}
58
59pub(crate) unsafe fn create_global_const_double(module: LLVMModuleRef, name: String, value: f64) -> LLVMValueRef {
60    let context = LLVMGetModuleContext(module);
61    let typ = LLVMDoubleTypeInContext(context);
62    let llval = LLVMAddGlobal(module, typ, cstring!(name));
63    let initializer = LLVMConstReal(typ, value);
64    LLVMSetGlobalConstant(llval, 1);
65    LLVMSetInitializer(llval, initializer);
66    llval
67}
68
69pub(crate) unsafe fn create_global_const_int64(module: LLVMModuleRef, name: String, value: i64) -> LLVMValueRef {
70    let context = LLVMGetModuleContext(module);
71    let typ = LLVMInt64TypeInContext(context);
72    let llval = LLVMAddGlobal(module, typ, cstring!(name));
73    let initializer = LLVMConstInt(typ, value as c_ulonglong, 0); // TODO: Sign extend?
74    LLVMSetGlobalConstant(llval, 1);
75    LLVMSetInitializer(llval, initializer);
76    llval
77}
78
79pub(crate) unsafe fn create_global_const(module: LLVMModuleRef, name: String, value: FipsValue) -> LLVMValueRef {
80    match value {
81        FipsValue::Int64(value) => create_global_const_int64(module, name, value),
82        FipsValue::Double(value) => create_global_const_double(module, name, value)
83    }
84}
85
86unsafe fn __create_global_ptr(context: LLVMContextRef, module: LLVMModuleRef, name: String, scalar_type: LLVMTypeRef,
87    ptr: usize, stride: u32) -> LLVMValueRef 
88{
89    let element_type = match stride {
90        1 => scalar_type,
91        n => LLVMArrayType(scalar_type, n as c_uint)
92    };
93    let typ = LLVMPointerType(element_type, 0);
94    let llval = LLVMAddGlobal(module, typ, cstring!(name));
95    // Here we just assume the native pointer size if 64 bit
96    let initializer = LLVMConstIntToPtr(LLVMConstInt(LLVMInt64TypeInContext(context), ptr as c_ulonglong, 0), typ);
97    LLVMSetGlobalConstant(llval, 1);
98    LLVMSetInitializer(llval, initializer);
99    llval
100}
101
102pub(crate) unsafe fn create_global_ptr(module: LLVMModuleRef, name: String, typ: &FipsType, ptr: usize) -> Result<LLVMValueRef> {
103    let context = LLVMGetModuleContext(module);
104    let (lltype, dims) = get_llvm_type_dims(context, typ)?;
105    Ok(match dims.len() {
106        0 => __create_global_ptr(context, module, name, lltype, ptr, 1),
107        1 => __create_global_ptr(context, module, name, lltype, ptr, dims[0] as u32),
108        _ => unimplemented!("No multidim support for now"),
109    })
110}
111
112pub(crate) unsafe fn create_local_ptr(module: LLVMModuleRef, builder: LLVMBuilderRef, name: String, typ: &FipsType) -> Result<LLVMValueRef> {
113    let context = LLVMGetModuleContext(module);
114    let (lltyp, dims) = get_llvm_type_dims(context, typ)?;
115    let typ = match dims.len() {
116        0 => lltyp,
117        1 => LLVMArrayType(lltyp, dims[0] as c_uint),
118        _ => unimplemented!("No multidim support for now"),
119    };
120    Ok(LLVMBuildAlloca(builder, typ, cstring!(name)))
121}
122
123// TODO: Integrate these with the above
124pub(crate) unsafe fn fips_value_2_llvm(module: LLVMModuleRef, value: &FipsValue) -> LLVMValueRef {
125    match value {
126        FipsValue::Int64(value) => {
127            let context = LLVMGetModuleContext(module);
128            let typ = LLVMInt64TypeInContext(context);
129            LLVMConstInt(typ, *value as c_ulonglong, 0)
130        },
131        FipsValue::Double(value) => {
132            let context = LLVMGetModuleContext(module);
133            let typ = LLVMDoubleTypeInContext(context);
134            LLVMConstReal(typ, *value)
135        }
136    }
137}
138
139unsafe fn __fips_ptr_2_llvm(context: LLVMContextRef, scalar_type: LLVMTypeRef,
140    ptr: usize, stride: u32) -> LLVMValueRef 
141{
142    let element_type = match stride {
143        1 => scalar_type,
144        n => LLVMArrayType(scalar_type, n as c_uint)
145    };
146    let typ = LLVMPointerType(element_type, 0);
147    // Here we just assume the native pointer size if 64 bit
148    LLVMConstIntToPtr(LLVMConstInt(LLVMInt64TypeInContext(context), ptr as c_ulonglong, 0), typ)
149}
150
151pub(crate) unsafe fn fips_ptr_2_llvm(module: LLVMModuleRef, typ: &FipsType, ptr: usize) -> Result<LLVMValueRef> {
152    let context = LLVMGetModuleContext(module);
153    let (lltype, dims) = get_llvm_type_dims(context, typ)?;
154    Ok(match dims.len() {
155        0 => __fips_ptr_2_llvm(context, lltype, ptr, 1),
156        1 => __fips_ptr_2_llvm(context, lltype, ptr, dims[0] as u32),
157        _ => unimplemented!("No multidim support for now"),
158    })
159}
160pub(crate) unsafe fn create_neighbor_member_values<'a>(module: LLVMModuleRef, members: Vec<&'a str>, 
161    neighbor_list: &RwLockReadGuard<NeighborList>, particle_index: &ParticleIndex, particle_store: &ParticleStore) 
162-> HashMap<&'a str, Vec<LLVMValueRef>> {
163    members.iter().map(|member_name| {
164        let llvals = neighbor_list.pos_blocks.iter()
165            .map(move |(particle_id, index_range)| {
166                let particle_def = particle_index.get(*particle_id).unwrap();
167                let particle_data = particle_store.get_particle(*particle_id).unwrap();
168                let (member_id, member_def) = match particle_def.get_member_by_name(member_name) {
169                    None => { return None }
170                    Some(x) => x
171                };
172                let member_data = particle_data.borrow_member(&member_id).unwrap();
173                Some(match &*member_data {
174                    MemberData::Uniform(value) => {
175                        fips_value_2_llvm(module, value)
176                    }
177                    MemberData::PerParticle{data, ..} => {
178                        // TODO: Less unwrap
179                        let offset = index_range.start * member_def.get_member_size().unwrap();
180                        let data_ptr = data.as_ptr() as usize + offset;
181                        fips_ptr_2_llvm(module, member_def.get_type(), data_ptr).unwrap()
182                    }
183                })
184            }).collect::<Vec<_>>();
185        // TODO: For now, do no allow mixed allocation for neighbor members
186        let mut lltyp_check = None;
187        for llval in &llvals {
188            if let Some(llval) = llval {
189                lltyp_check = Some(LLVMTypeOf(*llval));
190                break;
191            }
192        }
193        let lltyp_check = lltyp_check.unwrap();
194        // TODO: Clean this up; fun fact: on debug builds the unwrap_or value is constructed no matter what
195        // so this crashes if we just try to generate NULL pointers for non-pointer types
196        let llvals = llvals.into_iter().map(|llval| llval.unwrap())//llval.unwrap_or(LLVMConstPointerNull(lltyp_check)))
197            .collect::<Vec<_>>();
198        for llval in &llvals {
199            assert_eq!(lltyp_check, LLVMTypeOf(*llval));
200        }
201        (*member_name, llvals)
202    }).collect::<HashMap<_,_>>()
203}
204
205pub(crate) unsafe fn build_loop(context: LLVMContextRef, builder: LLVMBuilderRef, block_loop_body: LLVMBasicBlockRef,
206    block_after_loop: LLVMBasicBlockRef, loop_index_ptr: LLVMValueRef, end_index: LLVMValueRef) -> LLVMBasicBlockRef {
207    // Create loop blocks
208    let block_loop_check = LLVMInsertBasicBlockInContext(context, block_after_loop, cstr!("loop_check"));
209    // let block_loop_body = LLVMInsertBasicBlockInContext(context, block_after_loop, cstr!("loop_body"));
210    let block_loop_increment = LLVMInsertBasicBlockInContext(context, block_after_loop, cstr!("loop_increment"));
211
212    // Create loop check
213    LLVMBuildBr(builder, block_loop_check);
214    LLVMPositionBuilderAtEnd(builder, block_loop_check);
215    let loop_index = LLVMBuildLoad(builder, loop_index_ptr, cstr!("loop_var_val"));
216    let end_index = match LLVMGetTypeKind(LLVMTypeOf(end_index)) {
217        LLVMTypeKind::LLVMPointerTypeKind => { LLVMBuildLoad(builder, end_index, cstr!("end_index")) }
218        _ => end_index
219    };
220    let comparison = LLVMBuildICmp(builder, LLVMIntPredicate::LLVMIntULT, loop_index, end_index, cstr!("loop_check"));
221    LLVMBuildCondBr(builder, comparison, block_loop_body, block_after_loop);
222
223    // Create loop increment
224    LLVMPositionBuilderAtEnd(builder, block_loop_increment);
225    let loop_index = LLVMBuildLoad(builder, loop_index_ptr, cstr!("loop_var_val"));
226    let llone = LLVMConstInt(LLVMInt64TypeInContext(context), 1, 0);
227    let incremented_index = LLVMBuildAdd(builder, loop_index, llone, cstr!("incremented_val"));
228    LLVMBuildStore(builder, incremented_index, loop_index_ptr);
229    LLVMBuildBr(builder, block_loop_check);
230
231    block_loop_increment
232}
233
234pub(crate) unsafe fn llmultiply_by_minus_one(context: LLVMContextRef, builder: LLVMBuilderRef, value: LLVMValueRef) -> LLVMValueRef {
235    match LLVMGetTypeKind(LLVMTypeOf(value)) {
236        LLVMTypeKind::LLVMArrayTypeKind => {
237            let length = LLVMGetArrayLength(LLVMTypeOf(value));
238            llmultiply_by_minus_one(context, builder, promote_to_vector(context, builder, value, length))
239        }
240        LLVMTypeKind::LLVMVectorTypeKind => {
241            let scalar = scalar_minus_one(context, LLVMGetElementType(LLVMTypeOf(value)));
242            let length = LLVMGetVectorSize(LLVMTypeOf(value));
243            let llminus = LLVMConstVector(vec![scalar; length as usize].as_mut_ptr(), length);
244            match LLVMGetTypeKind(LLVMGetElementType(LLVMTypeOf(value))) {
245                LLVMTypeKind::LLVMDoubleTypeKind => {
246                    LLVMBuildFMul(builder, llminus, value, cstr!("negated"))
247                }
248                LLVMTypeKind::LLVMIntegerTypeKind => {
249                    LLVMBuildMul(builder, llminus, value, cstr!("negated"))
250                }
251                _ => panic!()
252            }
253        }
254        LLVMTypeKind::LLVMDoubleTypeKind => {
255            let llminus = scalar_minus_one(context, LLVMTypeOf(value));
256            LLVMBuildFMul(builder, llminus, value, cstr!("negated"))
257        }
258        LLVMTypeKind::LLVMIntegerTypeKind => {
259            let llminus = scalar_minus_one(context, LLVMTypeOf(value));
260            LLVMBuildMul(builder, llminus, value, cstr!("negated"))
261        }
262        _ => panic!()
263    }
264}
265
266unsafe fn scalar_minus_one(context: LLVMContextRef, typ: LLVMTypeRef) -> LLVMValueRef {
267    let double_type = LLVMDoubleTypeInContext(context);
268    let int64_type = LLVMInt64TypeInContext(context);
269    match LLVMGetTypeKind(typ) {
270        LLVMTypeKind::LLVMDoubleTypeKind => {
271            LLVMConstReal(double_type, -1.0)
272        }
273        LLVMTypeKind::LLVMIntegerTypeKind => { 
274            LLVMConstInt(int64_type, std::mem::transmute(-1i64), 1)
275        }
276        _ => panic!()
277    }
278}
279
280pub(crate) unsafe fn calculate_distance_sqr_and_vec(context: LLVMContextRef, builder: LLVMBuilderRef,
281    pos_1: LLVMValueRef, pos_2: LLVMValueRef) 
282-> (LLVMValueRef, LLVMValueRef) {
283    // Cheat a bit and create a pseudo parser expression for the difference
284    let dist_vec = evaluate_binop(context, builder, pos_2, pos_1, BinaryOperator::Sub).unwrap();
285    // Now square the elements of the dist vector
286    let dist_vec_sqr = evaluate_binop(context, builder, dist_vec, dist_vec, BinaryOperator::Mul).unwrap();
287    // Finally add all elements of the distance vector squared
288    let mut dist = LLVMBuildExtractElement(builder, dist_vec_sqr, 
289        LLVMConstInt(LLVMInt64TypeInContext(context), 0, 0), cstr!("dist_acc"));
290    for i in 1..LLVMGetVectorSize(LLVMTypeOf(dist_vec_sqr)) {
291        dist = LLVMBuildFAdd(builder, dist,
292            LLVMBuildExtractElement(builder, dist_vec_sqr, 
293                LLVMConstInt(LLVMInt64TypeInContext(context), i as u64, 0), cstr!("dist_elem")),
294            cstr!("dist_acc")
295        );
296    }
297    (dist, dist_vec)
298}
299
300pub(crate) unsafe fn correct_postion_vector(context: LLVMContextRef, builder: LLVMBuilderRef,
301    position: LLVMValueRef, other_position: LLVMValueRef, cutoff_skin: f64, domain: &Domain)
302-> LLVMValueRef {
303    let double_type = LLVMDoubleTypeInContext(context);
304    match domain {
305        Domain::Dim2 { .. } => unimplemented!(),
306        Domain::Dim3 { x, y, z } => {
307            // Assume domain size > 2*cutoff
308            assert!(x.size() > 2.*cutoff_skin);
309            assert!(y.size() > 2.*cutoff_skin);
310            assert!(z.size() > 2.*cutoff_skin);
311            // If this condition holds, the raw distance vector must have a component larger than the cutoff length
312            // or smaller than then negative cutoff length if it is incorrectly mirrored
313            let raw_dist_vec = evaluate_binop(context, builder, position, other_position, BinaryOperator::Sub).unwrap();
314            let cmp1 = LLVMBuildFCmp(builder, LLVMRealPredicate::LLVMRealOGT, raw_dist_vec, LLVMConstVector([
315                LLVMConstReal(double_type, cutoff_skin),
316                LLVMConstReal(double_type, cutoff_skin),
317                LLVMConstReal(double_type, cutoff_skin),
318            ].as_mut_ptr(), 3), cstr!("cmp_gt"));
319            let cmp2 = LLVMBuildFCmp(builder, LLVMRealPredicate::LLVMRealOLT, raw_dist_vec, LLVMConstVector([
320                LLVMConstReal(double_type, -cutoff_skin),
321                LLVMConstReal(double_type, -cutoff_skin),
322                LLVMConstReal(double_type, -cutoff_skin),
323            ].as_mut_ptr(), 3), cstr!("cmp_lt"));
324
325            // let valx = LLVMBuildExtractElement(builder, cmp1, LLVMConstInt(int8_type, 0, 0), cstr!(""));
326            // let valy = LLVMBuildExtractElement(builder, cmp1, LLVMConstInt(int8_type, 1, 0), cstr!(""));
327            // let valz = LLVMBuildExtractElement(builder, cmp1, LLVMConstInt(int8_type, 2, 0), cstr!(""));
328            // LLVMBuildCall(builder, print_func_u64, [valx].as_mut_ptr(), 1, cstr!(""));
329            // LLVMBuildCall(builder, print_func_u64, [valy].as_mut_ptr(), 1, cstr!(""));
330            // LLVMBuildCall(builder, print_func_u64, [valz].as_mut_ptr(), 1, cstr!(""));
331
332            // Cast to double
333            let cmp1 = LLVMBuildUIToFP(builder, cmp1, LLVMVectorType(double_type, 3), cstr!("cmp_gt_dbl"));
334            let cmp2 = LLVMBuildUIToFP(builder, cmp2, LLVMVectorType(double_type, 3), cstr!(""));
335            let cmp2 = LLVMBuildFMul(builder, cmp2, LLVMConstVector([
336                LLVMConstReal(double_type, -1.),
337                LLVMConstReal(double_type, -1.),
338                LLVMConstReal(double_type, -1.),
339            ].as_mut_ptr(), 3), cstr!("cmp_lt_dbl"));
340            // Add
341            let cmp = LLVMBuildFAdd(builder, cmp1, cmp2, cstr!("cmp"));
342            // Multiply domain length
343            let corr_vec = LLVMBuildFMul(builder, cmp, LLVMConstVector([
344                LLVMConstReal(double_type, x.size()),
345                LLVMConstReal(double_type, y.size()),
346                LLVMConstReal(double_type, z.size()),
347            ].as_mut_ptr(), 3), cstr!("correction_vec"));
348            evaluate_binop(context, builder, other_position, corr_vec, BinaryOperator::Add).unwrap()
349        }
350    }
351}