use std::rc::Rc;
use anyhow::Result;
use delegate::delegate;
use hugr_core::extension::ExtensionId;
use hugr_core::types::{SumType, Type, TypeName};
use inkwell::types::FunctionType;
use inkwell::{context::Context, types::BasicTypeEnum};
use crate::custom::types::{LLVMCustomTypeFn, LLVMTypeMapping};
pub use crate::sum::LLVMSumType;
use crate::utils::type_map::TypeMap;
/// A HUGR type, as accepted by [`TypingSession::llvm_type`].
pub type HugrType = Type;
/// A HUGR function signature, as accepted by [`TypingSession::llvm_func_type`].
pub type HugrFuncType = hugr_core::types::Signature;
/// A HUGR sum type, as accepted by [`TypingSession::llvm_sum_type`].
pub type HugrSumType = SumType;
/// Lowers HUGR types to LLVM types within one inkwell [`Context`].
///
/// A session pairs the context with a shared [`TypeConverter`]. Cloning only
/// copies the context reference and bumps the converter's reference count.
#[derive(Clone)]
pub struct TypingSession<'c, 'a> {
    /// The inkwell context; its lifetime `'c` bounds every LLVM type this
    /// session returns.
    iw_context: &'c Context,
    /// The shared converter that every conversion method forwards to.
    type_converter: Rc<TypeConverter<'a>>,
}
impl<'c, 'a> TypingSession<'c, 'a> {
    /// Builds a session that lowers types into `iw_context` using the shared
    /// `type_converter`.
    pub fn new(iw_context: &'c Context, type_converter: Rc<TypeConverter<'a>>) -> Self {
        Self {
            type_converter,
            iw_context,
        }
    }

    /// Returns the inkwell [`Context`] this session was built with.
    pub fn iw_context(&self) -> &'c Context {
        self.iw_context
    }

    // Every conversion forwards to the shared `TypeConverter`. A clone of this
    // session is appended as the converter's `context` argument.
    delegate! {
        to self.type_converter.clone() {
            /// Converts `hugr_type` into an LLVM [`BasicTypeEnum`].
            ///
            /// Any error reported by the underlying [`TypeConverter`] is propagated.
            pub fn llvm_type(&self, [self.clone()], hugr_type: &HugrType) -> Result<BasicTypeEnum<'c>>;
            /// Converts the HUGR signature `hugr_type` into an LLVM [`FunctionType`].
            ///
            /// Any error reported by the underlying [`TypeConverter`] is propagated.
            pub fn llvm_func_type(&self, [self.clone()], hugr_type: &HugrFuncType) -> Result<FunctionType<'c>>;
            /// Converts the HUGR sum type `hugr_type` into an [`LLVMSumType`].
            ///
            /// Any error reported by the underlying [`TypeConverter`] is propagated.
            pub fn llvm_sum_type(&self, [self.clone()], hugr_type: HugrSumType) -> Result<LLVMSumType<'c>>;
        }
    }
}
/// Converts HUGR types into LLVM types.
///
/// Custom (extension-defined) types are lowered by callbacks registered through
/// [`TypeConverter::custom_type`]. The [`Default`] value wraps
/// `TypeMap::default()`.
#[derive(Default)]
pub struct TypeConverter<'a>(
    // Holds the per-custom-type callbacks, keyed by (extension id, type name),
    // and performs the conversions.
    TypeMap<'a, LLVMTypeMapping<'a>>,
);
impl<'a> TypeConverter<'a> {
pub(super) fn custom_type(
&mut self,
custom_type: (ExtensionId, TypeName),
handler: impl LLVMCustomTypeFn<'a>,
) {
self.0.set_callback(custom_type, handler);
}
pub fn llvm_type<'c>(
self: Rc<Self>,
context: TypingSession<'c, 'a>,
hugr_type: &HugrType,
) -> Result<BasicTypeEnum<'c>> {
self.0.map_type(hugr_type, context)
}
pub fn llvm_func_type<'c>(
self: Rc<Self>,
context: TypingSession<'c, 'a>,
hugr_type: &HugrFuncType,
) -> Result<FunctionType<'c>> {
self.0.map_function_type(hugr_type, context)
}
pub fn llvm_sum_type<'c>(
self: Rc<Self>,
context: TypingSession<'c, 'a>,
hugr_type: HugrSumType,
) -> Result<LLVMSumType<'c>> {
self.0.map_sum_type(&hugr_type, context)
}
pub fn session<'c>(self: Rc<Self>, iw_context: &'c Context) -> TypingSession<'c, 'a> {
TypingSession::new(iw_context, self)
}
}
/// Snapshot tests for lowering HUGR types to LLVM types.
///
/// Each test obtains a typing session from the `llvm_ctx` fixture, converts a
/// HUGR type with it, and records the resulting LLVM type with `insta`.
#[cfg(test)]
// NOTE(review): the reason for allowing `drop_bounds` is not visible in this
// module; confirm whether it is still needed.
#[allow(drop_bounds)]
pub mod test {
    use hugr_core::{
        std_extensions::arithmetic::int_types::INT_TYPES,
        type_row,
        types::{SumType, Type},
    };
    use insta::assert_snapshot;
    use rstest::rstest;

    use crate::{extension::int::add_int_extensions, test::*, types::HugrFuncType};

    // Lowers function signatures built from unit sums, including empty input
    // and empty output rows. `_id` is forwarded to the `llvm_ctx` fixture via
    // `#[with(_id)]`. Presumably this gives each case a distinct snapshot
    // identity; confirm in `crate::test`.
    #[rstest]
    #[case(0,HugrFuncType::new(type_row!(Type::new_unit_sum(2)), type_row!()))]
    #[case(1, HugrFuncType::new(Type::new_unit_sum(1), Type::new_unit_sum(3)))]
    #[case(2,HugrFuncType::new(vec![], vec![Type::new_unit_sum(1), Type::new_unit_sum(1)]))]
    fn func_types(#[case] _id: i32, #[with(_id)] llvm_ctx: TestContext, #[case] ft: HugrFuncType) {
        // The third argument is the expression text insta stores with the
        // snapshot; here it is the signature's `to_string()`.
        assert_snapshot!(
            "func_type_to_llvm",
            llvm_ctx.get_typing_session().llvm_func_type(&ft).unwrap(),
            &ft.to_string()
        )
    }

    // Lowers sum types: unary sums with 0, 1 and 2 variants, and a two-variant
    // sum whose variants carry unit-sum fields.
    #[rstest]
    #[case(0, SumType::new_unary(0))]
    #[case(1, SumType::new_unary(1))]
    #[case(2,SumType::new([vec![Type::new_unit_sum(0), Type::new_unit_sum(1)], vec![Type::new_unit_sum(2), Type::new_unit_sum(3)]]))]
    #[case(3, SumType::new_unary(2))]
    fn sum_types(#[case] _id: i32, #[with(_id)] llvm_ctx: TestContext, #[case] st: SumType) {
        assert_snapshot!(
            "sum_type_to_llvm",
            // `llvm_sum_type` takes the sum type by value; `st` is cloned so it
            // can still be rendered for the snapshot expression below.
            llvm_ctx
                .get_typing_session()
                .llvm_sum_type(st.clone())
                .unwrap(),
            &st.to_string()
        )
    }

    // Lowers types that involve the integer extension's custom types
    // (`INT_TYPES`): plain ints, sums containing ints, and a function type.
    #[rstest]
    #[case(0, INT_TYPES[0].clone())]
    #[case(1, INT_TYPES[3].clone())]
    #[case(2, INT_TYPES[4].clone())]
    #[case(3, INT_TYPES[5].clone())]
    #[case(4, INT_TYPES[6].clone())]
    #[case(5, Type::new_sum([vec![INT_TYPES[2].clone()]]))]
    #[case(6, Type::new_sum([vec![INT_TYPES[6].clone(),Type::new_unit_sum(1)], vec![Type::new_unit_sum(2), INT_TYPES[2].clone()]]))]
    #[case(7, Type::new_function(HugrFuncType::new(type_row!(Type::new_unit_sum(2)), Type::new_unit_sum(3))))]
    fn ext_types(#[case] _id: i32, #[with(_id)] mut llvm_ctx: TestContext, #[case] t: Type) {
        // Register the integer extensions before converting. This is presumably
        // what lets the `INT_TYPES` custom types be lowered.
        llvm_ctx.add_extensions(add_int_extensions);
        assert_snapshot!(
            "type_to_llvm",
            llvm_ctx.get_typing_session().llvm_type(&t).unwrap(),
            &t.to_string()
        );
    }
}