use pliron::{
arg_err,
builtin::{
op_interfaces::{OneResultInterface, SymbolTableInterface},
types::{IntegerType, Signedness},
},
context::{Context, Ptr},
identifier::Identifier,
irbuild::{inserter::Inserter, listener::InsertionListener},
result::Result,
symbol_table::SymbolTableCollection,
r#type::TypeObj,
value::Value,
};
use crate::{
op_interfaces::CastOpInterface,
ops::{FuncOp, GepIndex, GetElementPtrOp, PtrToIntOp, ZeroOp},
types::{FuncType, PointerType, VoidType},
};
#[derive(Debug, thiserror::Error)]
pub enum LookupOrInsertFunctionError {
#[error("Symbol '{0}' found but is not a function")]
SymbolNotFunction(Identifier),
#[error("Existing function '{0}' has a different type than the one being inserted")]
FunctionTypeMismatch(Identifier),
}
pub fn lookup_or_insert_function(
ctx: &mut Context,
symbol_table_collection: &mut SymbolTableCollection,
symbol_table_op: Box<dyn SymbolTableInterface>,
name: Identifier,
return_type: Ptr<TypeObj>,
param_types: Vec<Ptr<TypeObj>>,
is_var_arg: bool,
) -> Result<FuncOp> {
let loc = symbol_table_op.loc(ctx);
let func_ty = FuncType::get(ctx, return_type, param_types, is_var_arg);
let symbol_table = symbol_table_collection.get_symbol_table(ctx, symbol_table_op.clone());
if let Some(func) = symbol_table.lookup(&name) {
if let Some(func_op) = func.as_any().downcast_ref::<FuncOp>() {
let existing_func_ty = func_op.get_type(ctx);
if existing_func_ty != func_ty {
return arg_err!(loc, LookupOrInsertFunctionError::FunctionTypeMismatch(name));
}
Ok(*func_op)
} else {
arg_err!(loc, LookupOrInsertFunctionError::SymbolNotFunction(name))
}
} else {
let func = FuncOp::new(ctx, name.clone(), func_ty);
symbol_table.insert(ctx, Box::new(func), None)?;
Ok(func)
}
}
pub fn get_size_type(ctx: &mut Context) -> Ptr<TypeObj> {
IntegerType::get(ctx, 64, Signedness::Signless).into()
}
pub fn lookup_or_create_malloc_fn(
ctx: &mut Context,
symbol_table_collection: &mut SymbolTableCollection,
symbol_table_op: Box<dyn SymbolTableInterface>,
) -> Result<FuncOp> {
let size_ty = get_size_type(ctx);
lookup_or_insert_function(
ctx,
symbol_table_collection,
symbol_table_op,
"malloc".try_into().unwrap(),
PointerType::get(ctx).into(),
vec![size_ty],
false,
)
}
pub fn lookup_or_create_free_fn(
ctx: &mut Context,
symbol_table_collection: &mut SymbolTableCollection,
symbol_table_op: Box<dyn SymbolTableInterface>,
) -> Result<FuncOp> {
let ptr_ty = PointerType::get(ctx).into();
lookup_or_insert_function(
ctx,
symbol_table_collection,
symbol_table_op,
"free".try_into().unwrap(),
VoidType::get(ctx).into(),
vec![ptr_ty],
false,
)
}
pub fn compute_type_size_in_bytes<L: InsertionListener, I: Inserter<L>>(
ctx: &mut Context,
inserter: &mut I,
ty: Ptr<TypeObj>,
) -> Value {
let size_ty = get_size_type(ctx);
let pointer_ty = PointerType::get(ctx).into();
let zero_op = ZeroOp::new(ctx, pointer_ty);
inserter.append_op(ctx, zero_op);
let gep_op = GetElementPtrOp::new(
ctx,
zero_op.get_result(ctx),
vec![GepIndex::Constant(1)],
ty,
);
inserter.append_op(ctx, gep_op);
let ptr_to_int_op = PtrToIntOp::new(ctx, gep_op.get_result(ctx), size_ty);
inserter.append_op(ctx, ptr_to_int_op);
ptr_to_int_op.get_result(ctx)
}
#[cfg(test)]
mod tests {
use expect_test::expect;
use pliron::{
builtin::{
op_interfaces::{
CallOpCallable, OneResultInterface, SingleBlockRegionInterface, SymbolOpInterface,
},
ops::ModuleOp,
types::FP64Type,
},
context::Context,
init_env_logger_for_tests,
irbuild::{
inserter::{IRInserter, Inserter, OpInsertionPoint},
listener::DummyListener,
},
op::{Op, verify_op},
result::ExpectOk,
};
use crate::{
function_call_utils::{
compute_type_size_in_bytes, get_size_type, lookup_or_create_free_fn,
lookup_or_create_malloc_fn,
},
llvm_sys::{core::LLVMContext, lljit::LLVMLLJIT, target},
ops::{CallOp, FuncOp, ReturnOp},
to_llvm_ir::convert_module,
types::FuncType,
};
#[test]
fn test_malloc_and_free_integration() {
init_env_logger_for_tests!();
let mut ctx = Context::new();
let mut symbol_table_collection = pliron::symbol_table::SymbolTableCollection::new();
let module = ModuleOp::new(&mut ctx, "test_module".try_into().unwrap());
let module_box = Box::new(module);
let malloc_fn =
lookup_or_create_malloc_fn(&mut ctx, &mut symbol_table_collection, module_box.clone())
.expect("Failed to create malloc function");
let free_fn =
lookup_or_create_free_fn(&mut ctx, &mut symbol_table_collection, module_box.clone())
.expect("Failed to create free function");
assert_eq!(
malloc_fn.get_symbol_name(&ctx),
"malloc".try_into().unwrap()
);
assert_eq!(free_fn.get_symbol_name(&ctx), "free".try_into().unwrap());
let malloc_fn_2 =
lookup_or_create_malloc_fn(&mut ctx, &mut symbol_table_collection, module_box.clone())
.expect("Failed to get malloc function again");
assert!(
malloc_fn == malloc_fn_2,
"Expected to get the same malloc function on second lookup"
);
let return_type = get_size_type(&mut ctx);
let func_ty = FuncType::get(&mut ctx, return_type, vec![], false);
let main_fn = FuncOp::new(&mut ctx, "main".try_into().unwrap(), func_ty);
main_fn
.get_operation()
.insert_at_front(module.get_body(&ctx, 0), &ctx);
let entry = main_fn.get_or_create_entry_block(&mut ctx);
let mut inserter = IRInserter::<DummyListener>::new(OpInsertionPoint::AtBlockEnd(entry));
let fp_ty = FP64Type::get(&ctx);
let fp_ty_size = compute_type_size_in_bytes(&mut ctx, &mut inserter, fp_ty.into());
let callee = CallOpCallable::Direct(malloc_fn.get_symbol_name(&ctx));
let callee_ty = malloc_fn.get_type(&ctx);
let args = vec![fp_ty_size];
let malloc_call = CallOp::new(&mut ctx, callee, callee_ty, args);
inserter.append_op(&ctx, malloc_call);
let ptr_result = malloc_call.get_result(&ctx);
let free_callee = CallOpCallable::Direct(free_fn.get_symbol_name(&ctx));
let free_callee_ty = free_fn.get_type(&ctx);
let free_args = vec![ptr_result];
let free_call = CallOp::new(&mut ctx, free_callee, free_callee_ty, free_args);
inserter.append_op(&ctx, free_call);
let ret_op = ReturnOp::new(&mut ctx, Some(fp_ty_size));
inserter.append_op(&ctx, ret_op);
verify_op(&module, &ctx).expect_ok(&ctx);
let llvm_ctx = LLVMContext::default();
let llvm_ir = convert_module(&ctx, &llvm_ctx, module).expect_ok(&ctx);
expect![[r#"
; ModuleID = 'test_module'
source_filename = "test_module"
define i64 @main() {
entry_block2v1:
%v3 = call ptr @malloc(i64 ptrtoint (ptr getelementptr (double, ptr null, i32 1) to i64))
call void @free(ptr %v3)
ret i64 ptrtoint (ptr getelementptr (double, ptr null, i32 1) to i64)
}
declare ptr @malloc(i64)
declare void @free(ptr)
"#]]
.assert_eq(&llvm_ir.to_string());
llvm_ir.verify().expect("Generated LLVM IR is invalid");
target::initialize_native().expect("Failed to initialize native target for JIT");
let jit = LLVMLLJIT::new_with_default_builder().expect("Failed to create LLJIT instance");
jit.add_module(llvm_ir)
.expect("Failed to add module to JIT");
let main_addr = jit
.lookup_symbol("main")
.expect("Failed to lookup 'main' symbol");
let main_fn = unsafe { std::mem::transmute::<u64, fn() -> u64>(main_addr) };
let fp_ty_size = main_fn();
assert_eq!(
fp_ty_size, 8,
"Expected size of double type to be 8 bytes, got {}",
fp_ty_size
);
}
}