use std::{collections::HashMap, ffi::CString};
use anyhow::{anyhow, Result};
use crate::{codegen::util::unwrap_usize_constant, cstr, cstring, parser::{self, BinaryOperator}, runtime::FunctionIndex};
use super::{LLSymbolValue, analysis::{SymbolTable}};
use llvm_sys::{LLVMTypeKind, prelude::*};
use llvm_sys::core::*;
/// Emits the integer arithmetic instruction selected by `op` for `lhs` and `rhs`.
///
/// Division is emitted as a signed division (`LLVMBuildSDiv`). The value
/// returned is the result of the chosen LLVM builder call.
///
/// # Safety
/// `builder`, `lhs` and `rhs` are passed straight to the LLVM C API, so the
/// caller must supply handles LLVM accepts for these instructions.
unsafe fn int_asmd(builder: LLVMBuilderRef, op: parser::BinaryOperator, lhs: LLVMValueRef, rhs: LLVMValueRef) -> LLVMValueRef {
    match op {
        parser::BinaryOperator::Div => LLVMBuildSDiv(builder, lhs, rhs, cstr!("tmp_div")),
        parser::BinaryOperator::Mul => LLVMBuildMul(builder, lhs, rhs, cstr!("tmp_mul")),
        parser::BinaryOperator::Sub => LLVMBuildSub(builder, lhs, rhs, cstr!("tmp_sub")),
        parser::BinaryOperator::Add => LLVMBuildAdd(builder, lhs, rhs, cstr!("tmp_add")),
    }
}
/// Emits the floating-point arithmetic instruction selected by `op` for `lhs`
/// and `rhs` (`fadd`, `fsub`, `fmul` or `fdiv` builder call).
///
/// # Safety
/// `builder`, `lhs` and `rhs` are passed straight to the LLVM C API, so the
/// caller must supply handles LLVM accepts for these instructions.
unsafe fn float_asmd(builder: LLVMBuilderRef, op: parser::BinaryOperator, lhs: LLVMValueRef, rhs: LLVMValueRef) -> LLVMValueRef {
    match op {
        parser::BinaryOperator::Div => LLVMBuildFDiv(builder, lhs, rhs, cstr!("tmp_div")),
        parser::BinaryOperator::Mul => LLVMBuildFMul(builder, lhs, rhs, cstr!("tmp_mul")),
        parser::BinaryOperator::Sub => LLVMBuildFSub(builder, lhs, rhs, cstr!("tmp_sub")),
        parser::BinaryOperator::Add => LLVMBuildFAdd(builder, lhs, rhs, cstr!("tmp_add")),
    }
}
/// Emits the arithmetic instruction selected by `op` for two vector operands.
///
/// The element type kind of each operand decides which scalar helper is used:
/// two double vectors go through `float_asmd`, two integer vectors through
/// `int_asmd`.
///
/// # Panics
/// Panics via `unimplemented!` when one operand has double elements and the
/// other integer elements, and via `panic!` for any other element kind pair.
///
/// # Safety
/// All handles are passed straight to the LLVM C API; the element type of
/// both operands is queried, so both are expected to have a type with an
/// element type.
unsafe fn vector_asmd(builder: LLVMBuilderRef, op: parser::BinaryOperator, lhs: LLVMValueRef, rhs: LLVMValueRef) -> LLVMValueRef {
    let left_elem_kind = LLVMGetTypeKind(LLVMGetElementType(LLVMTypeOf(lhs)));
    let right_elem_kind = LLVMGetTypeKind(LLVMGetElementType(LLVMTypeOf(rhs)));

    match (left_elem_kind, right_elem_kind) {
        (LLVMTypeKind::LLVMIntegerTypeKind, LLVMTypeKind::LLVMIntegerTypeKind) => int_asmd(builder, op, lhs, rhs),
        (LLVMTypeKind::LLVMDoubleTypeKind, LLVMTypeKind::LLVMDoubleTypeKind) => float_asmd(builder, op, lhs, rhs),
        (LLVMTypeKind::LLVMIntegerTypeKind, LLVMTypeKind::LLVMDoubleTypeKind)
        | (LLVMTypeKind::LLVMDoubleTypeKind, LLVMTypeKind::LLVMIntegerTypeKind) => {
            unimplemented!("Cannot combine double and int vectors yet")
        }
        _ => panic!(),
    }
}
pub(crate) unsafe fn promote_to_vector(context: LLVMContextRef, builder: LLVMBuilderRef, llval: LLVMValueRef, length: u32) -> LLVMValueRef {
let llkind = LLVMGetTypeKind(LLVMTypeOf(llval));
match llkind {
LLVMTypeKind::LLVMVectorTypeKind => {
llval
}
LLVMTypeKind::LLVMArrayTypeKind => {
assert_eq!(LLVMGetArrayLength(LLVMTypeOf(llval)), length);
let elem_kind = LLVMGetTypeKind(LLVMGetElementType(LLVMTypeOf(llval)));
let mut llvec = match elem_kind {
LLVMTypeKind::LLVMDoubleTypeKind => {
LLVMGetUndef(
LLVMVectorType(LLVMDoubleTypeInContext(context), length)
)
}
LLVMTypeKind::LLVMIntegerTypeKind => {
LLVMGetUndef(
LLVMVectorType(LLVMInt64TypeInContext(context), length)
)
}
_ => unreachable!()
};
for i in 0..length {
let llelement = LLVMBuildExtractValue(builder, llval, i, cstr!(""));
llvec = LLVMBuildInsertElement(builder, llvec, llelement,
LLVMConstInt(LLVMInt32TypeInContext(context), i as u64, 0),
cstring!("tmp_promoted_vec"));
}
llvec
}
LLVMTypeKind::LLVMDoubleTypeKind | LLVMTypeKind::LLVMIntegerTypeKind => {
let mut llvec = match llkind {
LLVMTypeKind::LLVMDoubleTypeKind => {
LLVMGetUndef(
LLVMVectorType(LLVMDoubleTypeInContext(context), length)
)
}
LLVMTypeKind::LLVMIntegerTypeKind => {
LLVMGetUndef(
LLVMVectorType(LLVMInt64TypeInContext(context), length)
)
}
_ => unreachable!()
};
for i in 0..length {
llvec = LLVMBuildInsertElement(builder, llvec, llval,
LLVMConstInt(LLVMInt32TypeInContext(context), i as u64, 0),
cstring!("tmp_promoted_vec"));
}
llvec
}
_ => panic!()
}
}
/// Converts a vector value into an array with the same element type and
/// length; scalars (double, integer) and arrays are returned untouched.
///
/// For a vector, every lane is extracted and inserted at the same position
/// of an initially undefined array.
///
/// # Panics
/// Panics if `llval` is of any type kind other than double, integer, array
/// or vector.
///
/// # Safety
/// `context`, `builder` and `llval` are passed straight to the LLVM C API,
/// so the caller must supply handles LLVM accepts.
pub(crate) unsafe fn convert_to_scalar_or_array(context: LLVMContextRef, builder: LLVMBuilderRef, llval: LLVMValueRef) -> LLVMValueRef {
    let value_type = LLVMTypeOf(llval);
    match LLVMGetTypeKind(value_type) {
        LLVMTypeKind::LLVMVectorTypeKind => {
            let lane_count = LLVMGetVectorSize(value_type);
            let lane_type = LLVMGetElementType(value_type);
            let empty_array = LLVMGetUndef(LLVMArrayType(lane_type, lane_count));
            // Thread the partially filled array through every lane index.
            (0..lane_count).fold(empty_array, |partial, lane| {
                let lane_index = LLVMConstInt(LLVMInt32TypeInContext(context), lane as u64, 0);
                let lane_value = LLVMBuildExtractElement(builder, llval, lane_index, cstr!(""));
                LLVMBuildInsertValue(builder, partial, lane_value, lane, cstr!("tmp_vec2array"))
            })
        }
        LLVMTypeKind::LLVMArrayTypeKind
        | LLVMTypeKind::LLVMIntegerTypeKind
        | LLVMTypeKind::LLVMDoubleTypeKind => llval,
        _ => panic!(),
    }
}
pub(crate) unsafe fn evaluate_expression(context: LLVMContextRef, builder: LLVMBuilderRef, expr: &parser::Expression,
symbol_table: &SymbolTable<LLSymbolValue>,
namespace_symbols: &HashMap<&String, HashMap<String, LLVMValueRef>>,
function_index: &FunctionIndex,
callback_target_ptrptr: LLVMValueRef) -> Result<LLVMValueRef>
{
match expr {
parser::Expression::Atom(atom) => {
match &atom {
parser::Atom::Variable(name) => {
let value = symbol_table.resolve_symbol(name)
.ok_or(anyhow!("Cannot resolve symbol {}", name))?;
let llname = format!("load_{}", name);
let llptr = match value.value.as_ref()
.expect(&format!("Incomplete symbol table: symbol for {} has no value", name))
{
LLSymbolValue::SimplePointer(llptr) => { *llptr }
LLSymbolValue::ParticleMember { local_ptr, .. } =>
{
local_ptr.expect(&format!("Local pointer not set for symbol {}", name))
}
LLSymbolValue::Function { .. } => panic!("Cannot evaluate function as atomic expression")
};
let llval = LLVMBuildLoad(builder, llptr, cstring!(llname));
Ok(llval)
}
parser::Atom::Literal(literal) => {
match literal {
parser::Literal::Double(val) => {
let typ = LLVMDoubleTypeInContext(context);
Ok(LLVMConstReal(typ, *val))
}
parser::Literal::Int64(val) => {
let typ = LLVMInt64TypeInContext(context);
Ok(LLVMConstInt(typ, *val as u64, 1))
}
}
}
parser::Atom::NamespaceVariable{ namespace, name } => {
let llptr = namespace_symbols.get(namespace)
.ok_or(anyhow!("Namespace {} not found", namespace))?
.get(name).map(|x| x)
.ok_or(anyhow!("Symbol {} not found in namespace {}", name, namespace))?;
Ok(LLVMBuildLoad(builder, *llptr, cstring!(format!("{}_{}", namespace, name))))
}
}
}
parser::Expression::BinaryOperation(binop) => {
let lhs = evaluate_expression(context, builder, binop.lhs.as_ref(),
symbol_table, namespace_symbols, function_index, callback_target_ptrptr)?;
let rhs = evaluate_expression(context, builder, binop.rhs.as_ref(),
symbol_table, namespace_symbols, function_index, callback_target_ptrptr)?;
evaluate_binop(context, builder, lhs, rhs, binop.op)
}
parser::Expression::FunctionCall(call) => {
let symbol = symbol_table.resolve_symbol(&call.fn_name)
.ok_or(anyhow!("Cannot find function {}", call.fn_name))?;
let symbol_value = match symbol.value.as_ref()
.ok_or(anyhow!("No value for symbol for function {}", call.fn_name))?
{
LLSymbolValue::Function(val) => val,
_ => return Err(anyhow!("Invalid symbol value for function {}", call.fn_name))
};
let func = function_index.get(symbol_value.function_id)
.expect(&format!("Internal: cannot find function {} in index", call.fn_name));
let mut parameter_vals = call.parameters.iter()
.map(|expr| evaluate_expression(context, builder, expr,
symbol_table, namespace_symbols, function_index, callback_target_ptrptr))
.collect::<Result<Vec<_>>>()?;
if func.needs_callback_target_ptr() {
let callback_target_ptr = LLVMBuildLoad(builder, callback_target_ptrptr, cstr!("callback_target_ptr"));
parameter_vals.insert(0, callback_target_ptr);
}
Ok(func.build_call(context, builder, symbol_value, parameter_vals)?)
}
parser::Expression::Block(_) => { unimplemented!() }
parser::Expression::Indexing(indexop) => {
let array_val = evaluate_expression(context, builder,
&parser::Expression::Atom(parser::Atom::Variable(indexop.array.clone())),
symbol_table, namespace_symbols, function_index, callback_target_ptrptr)?;
let index = unwrap_usize_constant(&indexop.index)? as u32;
let name = format!("{}_indexed_at_{}", indexop.array, index);
Ok(LLVMBuildExtractValue(builder, array_val,
index, cstring!(name)))
},
parser::Expression::AdHocArray(adhocarray) => {
if adhocarray.elements.is_empty() {
unimplemented!("Empty arrays are not supported yet");
}
let elements = adhocarray.elements.iter()
.map(|element_expr| evaluate_expression(
context, builder, element_expr, symbol_table, namespace_symbols,
function_index, callback_target_ptrptr)
).collect::<Result<Vec<_>>>()?;
let lltyp = LLVMTypeOf(elements[0]);
let mut result = LLVMGetUndef(LLVMArrayType(lltyp, elements.len() as u32));
for (i, element) in elements.into_iter().enumerate() {
let element = convert_to_scalar_or_array(context, builder, element);
let name = format!("adhoc_array_{}", i);
result = LLVMBuildInsertValue(builder, result, element, i as u32, cstring!(name));
}
Ok(result)
},
}
}
/// Emits the instruction for the binary operation `op` on `lhs` and `rhs`.
///
/// Dispatch is by the LLVM type kind of the operands:
///
/// * two doubles or two integers use the scalar instruction directly;
/// * two vectors must have the same length and are combined lane-wise;
/// * a vector combined with a scalar or array promotes the other operand to
///   a vector of the same length first;
/// * arrays are promoted to vectors and the operation is re-dispatched, so
///   the result of an operation involving arrays is a vector.
///
/// # Errors
/// Returns an error when two vectors, or a vector and an array, differ in
/// length, when a vector is combined with an operand that is neither a
/// double/integer scalar nor an array, and for any other unsupported pair of
/// operand types.
///
/// # Panics
/// Panics via `unimplemented!` when a double scalar is combined with an
/// integer scalar.
///
/// # Safety
/// All LLVM handles are passed straight to the LLVM C API, so the caller
/// must supply handles LLVM accepts.
pub(crate) unsafe fn evaluate_binop(context: LLVMContextRef, builder: LLVMBuilderRef, lhs: LLVMValueRef, rhs: LLVMValueRef, op: BinaryOperator) -> Result<LLVMValueRef> {
    let lhs_type = LLVMTypeOf(lhs);
    let rhs_type = LLVMTypeOf(rhs);
    let lhs_type_kind = LLVMGetTypeKind(lhs_type);
    let rhs_type_kind = LLVMGetTypeKind(rhs_type);
    Ok(match (lhs_type_kind, rhs_type_kind) {
        (LLVMTypeKind::LLVMDoubleTypeKind, LLVMTypeKind::LLVMDoubleTypeKind) => {
            float_asmd(builder, op, lhs, rhs)
        }
        (LLVMTypeKind::LLVMIntegerTypeKind, LLVMTypeKind::LLVMIntegerTypeKind) => {
            int_asmd(builder, op, lhs, rhs)
        }
        (LLVMTypeKind::LLVMVectorTypeKind, LLVMTypeKind::LLVMVectorTypeKind) => {
            if LLVMGetVectorSize(lhs_type) != LLVMGetVectorSize(rhs_type) {
                return Err(anyhow!("Cannot combine vectors of different length in a binary operation"));
            }
            vector_asmd(builder, op, lhs, rhs)
        }
        (LLVMTypeKind::LLVMVectorTypeKind, _) | (_, LLVMTypeKind::LLVMVectorTypeKind) => {
            // Exactly one operand is a vector here; the other one gets promoted.
            let (vec_type, other_type, other_kind) =
                if matches!(lhs_type_kind, LLVMTypeKind::LLVMVectorTypeKind) {
                    (lhs_type, rhs_type, rhs_type_kind)
                } else {
                    (rhs_type, lhs_type, lhs_type_kind)
                };
            let vec_length = LLVMGetVectorSize(vec_type);
            match other_kind {
                LLVMTypeKind::LLVMDoubleTypeKind | LLVMTypeKind::LLVMIntegerTypeKind => {}
                LLVMTypeKind::LLVMArrayTypeKind => {
                    // Report a mismatch as an error, like the vector/vector
                    // case, instead of tripping the assertion in the promotion.
                    if LLVMGetArrayLength(other_type) != vec_length {
                        return Err(anyhow!("Cannot combine a vector and an array of different length in a binary operation"));
                    }
                }
                _ => {
                    return Err(anyhow!("Cannot combine a vector with an operand that is neither a scalar nor an array in a binary operation"));
                }
            }
            // Promotion leaves the operand that already is a vector untouched.
            let lhs = promote_to_vector(context, builder, lhs, vec_length);
            let rhs = promote_to_vector(context, builder, rhs, vec_length);
            vector_asmd(builder, op, lhs, rhs)
        }
        (LLVMTypeKind::LLVMArrayTypeKind, LLVMTypeKind::LLVMArrayTypeKind) => {
            // Each array keeps its own length; the recursive call's
            // vector/vector case rejects a mismatch.
            let lhs = promote_to_vector(context, builder, lhs, LLVMGetArrayLength(lhs_type));
            let rhs = promote_to_vector(context, builder, rhs, LLVMGetArrayLength(rhs_type));
            evaluate_binop(context, builder, lhs, rhs, op)?
        }
        (LLVMTypeKind::LLVMArrayTypeKind, _) => {
            let lhs = promote_to_vector(context, builder, lhs, LLVMGetArrayLength(lhs_type));
            evaluate_binop(context, builder, lhs, rhs, op)?
        }
        (_ , LLVMTypeKind::LLVMArrayTypeKind) => {
            let rhs = promote_to_vector(context, builder, rhs, LLVMGetArrayLength(rhs_type));
            evaluate_binop(context, builder, lhs, rhs, op)?
        }
        (LLVMTypeKind::LLVMDoubleTypeKind, LLVMTypeKind::LLVMIntegerTypeKind) |
        (LLVMTypeKind::LLVMIntegerTypeKind, LLVMTypeKind::LLVMDoubleTypeKind) => {
            unimplemented!("Cannot combine double and int scalars yet")
        }
        _ => return Err(anyhow!("Unsupported operand types in a binary operation")),
    })
}