use anyhow::{Result, anyhow};
use delegate::delegate;
use hugr_core::{
HugrView, Node, Visibility,
ops::{FuncDecl, FuncDefn, OpType},
types::PolyFuncType,
};
use inkwell::{
builder::Builder,
context::Context,
intrinsics::Intrinsic,
module::{Linkage, Module},
types::{AnyType, BasicType, BasicTypeEnum, FunctionType},
values::{BasicValueEnum, CallSiteValue, FunctionValue, GlobalValue},
};
use std::{collections::HashSet, rc::Rc};
use crate::types::{HugrFuncType, HugrSumType, HugrType, TypingSession};
use crate::{custom::CodegenExtsMap, types::LLVMSumType, utils::fat::FatNode};
pub mod args;
pub mod func;
pub mod libc;
pub mod namer;
pub mod ops;
pub use args::EmitOpArgs;
pub use func::{EmitFuncContext, RowPromise};
pub use namer::Namer;
pub use ops::emit_value;
/// Module-level emission state: the LLVM [`Module`] being populated together
/// with the shared services (codegen extensions and function namer) used to
/// populate it.
///
/// `'c` is the lifetime of the inkwell [`Context`]; `'a` is the lifetime
/// parameter of the [`CodegenExtsMap`] and must outlive `'c`.
pub struct EmitModuleContext<'c, 'a: 'c, H> {
    /// Context in which LLVM types are created (see `typing_session`).
    iw_context: &'c Context,
    /// Module that functions and globals are looked up in and added to.
    module: Module<'c>,
    /// Codegen extensions; their `type_converter` backs `typing_session`.
    extensions: Rc<CodegenExtsMap<'a, H>>,
    /// Derives LLVM symbol names for HUGR function nodes.
    namer: Rc<Namer>,
}
impl<'c, 'a, H> EmitModuleContext<'c, 'a, H> {
delegate! {
to self.typing_session() {
pub fn llvm_type(&self, hugr_type: &HugrType) -> Result<BasicTypeEnum<'c>>;
pub fn llvm_func_type(&self, hugr_type: &HugrFuncType) -> Result<FunctionType<'c>>;
pub fn llvm_sum_type(&self, sum_type: HugrSumType) -> Result<LLVMSumType<'c>>;
}
to self.namer {
pub fn name_func(&self, name: impl AsRef<str>, node: Node) -> String;
}
}
pub fn iw_context(&self) -> &'c Context {
self.iw_context
}
pub fn new(
iw_context: &'c Context,
module: Module<'c>,
namer: Rc<Namer>,
extensions: Rc<CodegenExtsMap<'a, H>>,
) -> Self {
Self {
iw_context,
module,
extensions,
namer,
}
}
pub fn module(&self) -> &Module<'c> {
&self.module
}
pub fn extensions(&self) -> Rc<CodegenExtsMap<'a, H>> {
self.extensions.clone()
}
pub fn typing_session(&self) -> TypingSession<'c, 'a> {
self.extensions
.type_converter
.clone()
.session(self.iw_context)
}
fn get_func_impl(
&self,
name: impl AsRef<str>,
func_ty: FunctionType<'c>,
linkage: Option<Linkage>,
) -> Result<FunctionValue<'c>> {
let func = self
.module()
.get_function(name.as_ref())
.unwrap_or_else(|| self.module.add_function(name.as_ref(), func_ty, linkage));
if func.get_type() != func_ty {
Err(anyhow!(
"Function '{}' has wrong type: expected: {func_ty} actual: {}",
name.as_ref(),
func.get_type()
))?;
}
Ok(func)
}
fn get_hugr_func_impl(
&self,
name: impl AsRef<str>,
node: Node,
func_ty: &PolyFuncType,
visibility: &Visibility,
) -> Result<FunctionValue<'c>> {
let func_ty = (func_ty.params().is_empty())
.then_some(func_ty.body())
.ok_or(anyhow!("function has type params"))?;
let llvm_func_ty = self.llvm_func_type(func_ty)?;
let name = self.name_func(name, node);
match visibility {
Visibility::Public => self.get_func_impl(name, llvm_func_ty, Some(Linkage::External)),
Visibility::Private => self.get_func_impl(name, llvm_func_ty, Some(Linkage::Private)),
_ => self.get_func_impl(name, llvm_func_ty, None),
}
}
pub fn get_func_defn<'hugr>(
&self,
node: FatNode<'hugr, FuncDefn, H>,
) -> Result<FunctionValue<'c>>
where
H: HugrView<Node = Node>,
{
self.get_hugr_func_impl(
node.func_name(),
node.node(),
node.signature(),
node.visibility(),
)
}
pub fn get_func_decl<'hugr>(
&self,
node: FatNode<'hugr, FuncDecl, H>,
) -> Result<FunctionValue<'c>>
where
H: HugrView<Node = Node>,
{
self.get_hugr_func_impl(
node.func_name(),
node.node(),
node.signature(),
node.visibility(),
)
}
pub fn get_extern_func(
&self,
symbol: impl AsRef<str>,
typ: FunctionType<'c>,
) -> Result<FunctionValue<'c>> {
self.get_func_impl(symbol, typ, Some(Linkage::External))
}
pub fn get_global(
&self,
symbol: impl AsRef<str>,
typ: impl BasicType<'c>,
constant: bool,
) -> Result<GlobalValue<'c>> {
let symbol = symbol.as_ref();
let typ = typ.as_basic_type_enum();
if let Some(global) = self.module().get_global(symbol) {
let global_type = global.get_value_type();
if global_type != typ.as_any_type_enum() {
Err(anyhow!(
"Global '{symbol}' has wrong type: expected: {typ} actual: {global_type}"
))?;
}
if global.is_constant() != constant {
Err(anyhow!(
"Global '{symbol}' has wrong constant-ness: expected: {constant} actual: {}",
global.is_constant()
))?;
}
Ok(global)
} else {
let global = self.module().add_global(typ, None, symbol.as_ref());
global.set_constant(constant);
Ok(global)
}
}
pub fn finish(self) -> Module<'c> {
self.module
}
}
/// A set of HUGR nodes. It is used both for the functions already emitted (the
/// `emitted` field of `EmitHugr`) and for the pending worklist in `emit_func`.
type EmissionSet = HashSet<Node>;
/// Drives emission of HUGR functions into an LLVM module. It records which
/// function nodes have already been emitted so that each body is emitted at
/// most once.
pub struct EmitHugr<'c, 'a: 'c, H> {
    /// `FuncDefn` nodes whose bodies have already been emitted.
    emitted: EmissionSet,
    /// Module-level state. It is moved into each function's emission context
    /// and taken back once that function is finished.
    module_context: EmitModuleContext<'c, 'a, H>,
}
impl<'c, 'a, H: HugrView<Node = Node>> EmitHugr<'c, 'a, H> {
delegate! {
to self.module_context {
pub fn iw_context(&self) -> &'c Context;
pub fn module(&self) -> &Module<'c>;
}
}
pub fn new(
iw_context: &'c Context,
module: Module<'c>,
namer: Rc<Namer>,
extensions: Rc<CodegenExtsMap<'a, H>>,
) -> Self {
assert_eq!(iw_context, &module.get_context());
Self {
emitted: Default::default(),
module_context: EmitModuleContext::new(iw_context, module, namer, extensions),
}
}
pub fn emit_func(mut self, node: FatNode<'_, FuncDefn, H>) -> Result<Self> {
let mut worklist: EmissionSet = [node.node()].into_iter().collect();
let pop = |wl: &mut EmissionSet| wl.iter().next().copied().map(|x| wl.take(&x).unwrap());
while let Some(next_node) = pop(&mut worklist) {
use crate::utils::fat::FatExt as _;
let Some(func) = node.hugr().try_fat(next_node) else {
panic!(
"emit_func: node in worklist was not a FuncDefn: {:?}",
node.hugr().get_optype(next_node)
)
};
let (new_self, new_tasks) = self.emit_func_impl(func)?;
self = new_self;
worklist.extend(new_tasks.into_iter());
}
Ok(self)
}
pub fn emit_module(mut self, node: FatNode<'_, hugr_core::ops::Module, H>) -> Result<Self> {
for c in node.children() {
match c.as_ref() {
OpType::FuncDefn(fd) => {
let fat_ot = c.into_ot(fd);
self = self.emit_func(fat_ot)?;
}
OpType::FuncDecl(_) => (),
OpType::Const(_) => (),
_ => Err(anyhow!("Module has invalid child: {c}"))?,
}
}
Ok(self)
}
fn emit_func_impl(mut self, node: FatNode<'_, FuncDefn, H>) -> Result<(Self, EmissionSet)> {
if !self.emitted.insert(node.node()) {
return Ok((self, EmissionSet::default()));
}
let func = self.module_context.get_func_defn(node)?;
let mut func_ctx = EmitFuncContext::new(self.module_context, func)?;
let ret_rmb = func_ctx.new_row_mail_box(node.signature().body().output.iter(), "ret")?;
ops::emit_dataflow_parent(
&mut func_ctx,
EmitOpArgs {
node,
inputs: func.get_params(),
outputs: ret_rmb.promise(),
},
)?;
let builder = func_ctx.builder();
match &ret_rmb.read::<Vec<_>>(builder, [])?[..] {
[] => builder.build_return(None)?,
[x] => builder.build_return(Some(x))?,
xs => builder.build_aggregate_return(xs)?,
};
let (mctx, todos) = func_ctx.finish()?;
self.module_context = mctx;
Ok((self, todos))
}
pub fn finish(self) -> Module<'c> {
self.module_context.finish()
}
}
/// Splits the result of an LLVM call into `num_results` values.
///
/// * `0`: the call is expected to be `void`, and an empty vector is returned.
/// * `1`: the call's value is returned as the single element.
/// * `n > 1`: the call is expected to return a struct. Its fields `0..n` are
///   extracted with `extractvalue` instructions built by `builder`.
///
/// # Errors
/// If `num_results` does not fit in a `u32` (the index type taken by
/// `build_extract_value`), or if building an `extractvalue` fails.
///
/// # Panics
/// The `expect_*` / `into_struct_value` conversions are expected to panic if
/// the call's value kind does not match `num_results`.
pub fn deaggregate_call_result<'c>(
    builder: &Builder<'c>,
    call_result: CallSiteValue<'c>,
    num_results: usize,
) -> Result<Vec<BasicValueEnum<'c>>> {
    let call_result = call_result.try_as_basic_value();
    // A plain `as u32` cast would silently wrap. For example, 2^32 results would
    // become 0 and be treated as a void call, so the conversion is checked.
    let num_results = u32::try_from(num_results)?;
    Ok(match num_results {
        0 => {
            let _ = call_result.expect_instruction("void");
            vec![]
        }
        1 => vec![call_result.expect_basic("non-void")],
        n => {
            let return_struct = call_result.expect_basic("non-void").into_struct_value();
            (0..n)
                .map(|i| builder.build_extract_value(return_struct, i, ""))
                .collect::<Result<Vec<_>, _>>()?
        }
    })
}
/// Looks up the LLVM intrinsic `name` and returns its declaration in `module`
/// for the overload selected by the parameter types `args`.
///
/// # Errors
/// If no intrinsic called `name` exists, or if no declaration can be obtained
/// for it with the given `args`.
pub fn get_intrinsic<'c>(
    module: &Module<'c>,
    name: impl AsRef<str>,
    args: impl AsRef<[BasicTypeEnum<'c>]>,
) -> Result<FunctionValue<'c>> {
    let (name, args) = (name.as_ref(), args.as_ref());
    // `ok_or_else` formats the error messages only on failure.
    let intrinsic =
        Intrinsic::find(name).ok_or_else(|| anyhow!("Failed to find intrinsic: '{name}'"))?;
    intrinsic.get_declaration(module, args).ok_or_else(|| {
        anyhow!("failed to get_declaration for intrinsic '{name}' with args '{args:?}'")
    })
}
#[cfg(any(test, feature = "test-utils"))]
pub mod test;