hugr_llvm/
types.rs

1use std::rc::Rc;
2
3use anyhow::Result;
4use delegate::delegate;
5use hugr_core::extension::ExtensionId;
6use hugr_core::types::{SumType, Type, TypeName};
7use inkwell::types::FunctionType;
8use inkwell::{context::Context, types::BasicTypeEnum};
9
10use crate::custom::types::{LLVMCustomTypeFn, LLVMTypeMapping};
11pub use crate::sum::LLVMSumType;
12use crate::utils::type_map::TypeMap;
13
14/// A type alias for a hugr function type. We use this to disambiguate from
15/// the LLVM [FunctionType].
16pub type HugrFuncType = hugr_core::types::Signature;
17
18/// A type alias for a hugr type. We use this to disambiguate from LLVM types.
19pub type HugrType = Type;
20
21/// A type alias for a hugr sum type.
22pub type HugrSumType = SumType;
23
24/// A type that holds [Rc] shared pointers to everything needed to convert from
25/// a hugr [HugrType] to an LLVM [Type](inkwell::types).
26#[derive(Clone)]
27pub struct TypingSession<'c, 'a> {
28    iw_context: &'c Context,
29    type_converter: Rc<TypeConverter<'a>>,
30}
31
32impl<'c, 'a> TypingSession<'c, 'a> {
33    delegate! {
34        to self.type_converter.clone() {
35            /// Convert a [HugrType] into an LLVM [Type](BasicTypeEnum).
36            pub fn llvm_type(&self, [self.clone()], hugr_type: &HugrType) -> Result<BasicTypeEnum<'c>>;
37            /// Convert a [HugrFuncType] into an LLVM [FunctionType].
38            pub fn llvm_func_type(&self, [self.clone()], hugr_type: &HugrFuncType) -> Result<FunctionType<'c>>;
39            /// Convert a hugr [HugrSumType] into an LLVM [LLVMSumType].
40            pub fn llvm_sum_type(&self, [self.clone()], hugr_type: HugrSumType) -> Result<LLVMSumType<'c>>;
41        }
42    }
43
44    /// Creates a new `TypingSession`.
45    pub fn new(iw_context: &'c Context, type_converter: Rc<TypeConverter<'a>>) -> Self {
46        Self {
47            iw_context,
48            type_converter,
49        }
50    }
51
52    /// Returns a reference to the inner [Context].
53    pub fn iw_context(&self) -> &'c Context {
54        self.iw_context
55    }
56}
57
58#[derive(Default)]
59pub struct TypeConverter<'a>(TypeMap<'a, LLVMTypeMapping<'a>>);
60
61impl<'a> TypeConverter<'a> {
62    pub(super) fn custom_type(
63        &mut self,
64        custom_type: (ExtensionId, TypeName),
65        handler: impl LLVMCustomTypeFn<'a>,
66    ) {
67        self.0.set_callback(custom_type, handler);
68    }
69
70    pub fn llvm_type<'c>(
71        self: Rc<Self>,
72        context: TypingSession<'c, 'a>,
73        hugr_type: &HugrType,
74    ) -> Result<BasicTypeEnum<'c>> {
75        self.0.map_type(hugr_type, context)
76    }
77
78    pub fn llvm_func_type<'c>(
79        self: Rc<Self>,
80        context: TypingSession<'c, 'a>,
81        hugr_type: &HugrFuncType,
82    ) -> Result<FunctionType<'c>> {
83        self.0.map_function_type(hugr_type, context)
84    }
85
86    pub fn llvm_sum_type<'c>(
87        self: Rc<Self>,
88        context: TypingSession<'c, 'a>,
89        hugr_type: HugrSumType,
90    ) -> Result<LLVMSumType<'c>> {
91        self.0.map_sum_type(&hugr_type, context)
92    }
93
94    pub fn session<'c>(self: Rc<Self>, iw_context: &'c Context) -> TypingSession<'c, 'a> {
95        TypingSession::new(iw_context, self)
96    }
97}
98
/// Snapshot tests for lowering hugr types to LLVM.
#[cfg(test)]
// NOTE(review): no `Drop` bounds are visible in this module; presumably this
// silences a lint triggered by macro-generated test code — confirm before
// removing.
#[allow(drop_bounds)]
pub mod test {
    use hugr_core::{
        std_extensions::arithmetic::int_types::INT_TYPES,
        type_row,
        types::{SumType, Type},
    };
    use insta::assert_snapshot;
    use rstest::rstest;

    use crate::{extension::int::add_int_extensions, test::*, types::HugrFuncType};

    // The leading `_id` of each case is forwarded to the `llvm_ctx` fixture
    // via `#[with(_id)]`, so ids must stay stable for existing snapshots.

    #[rstest]
    #[case(0, HugrFuncType::new(type_row!(Type::new_unit_sum(2)), type_row!()))]
    #[case(1, HugrFuncType::new(Type::new_unit_sum(1), Type::new_unit_sum(3)))]
    #[case(2, HugrFuncType::new(vec![], vec![Type::new_unit_sum(1), Type::new_unit_sum(1)]))]
    fn func_types(#[case] _id: i32, #[with(_id)] llvm_ctx: TestContext, #[case] ft: HugrFuncType) {
        let description = ft.to_string();
        let session = llvm_ctx.get_typing_session();
        let lowered = session.llvm_func_type(&ft).unwrap();
        assert_snapshot!("func_type_to_llvm", lowered, &description);
    }

    #[rstest]
    #[case(0, SumType::new_unary(0))]
    #[case(1, SumType::new_unary(1))]
    #[case(2, SumType::new([vec![Type::new_unit_sum(4), Type::new_unit_sum(1)], vec![Type::new_unit_sum(2), Type::new_unit_sum(3)]]))]
    #[case(3, SumType::new_unary(2))]
    fn sum_types(#[case] _id: i32, #[with(_id)] llvm_ctx: TestContext, #[case] st: SumType) {
        // Render the description first so `st` can be moved into the converter.
        let description = st.to_string();
        let session = llvm_ctx.get_typing_session();
        let lowered = session.llvm_sum_type(st).unwrap();
        assert_snapshot!("sum_type_to_llvm", lowered, &description);
    }

    #[rstest]
    #[case(0, INT_TYPES[0].clone())]
    #[case(1, INT_TYPES[3].clone())]
    #[case(2, INT_TYPES[4].clone())]
    #[case(3, INT_TYPES[5].clone())]
    #[case(4, INT_TYPES[6].clone())]
    #[case(5, Type::new_sum([vec![INT_TYPES[2].clone()]]))]
    #[case(6, Type::new_sum([vec![INT_TYPES[6].clone(), Type::new_unit_sum(1)], vec![Type::new_unit_sum(2), INT_TYPES[2].clone()]]))]
    #[case(7, Type::new_function(HugrFuncType::new(type_row!(Type::new_unit_sum(2)), Type::new_unit_sum(3))))]
    fn ext_types(#[case] _id: i32, #[with(_id)] mut llvm_ctx: TestContext, #[case] t: Type) {
        // Integer types are custom types, so their lowering must be registered.
        llvm_ctx.add_extensions(add_int_extensions);
        let description = t.to_string();
        let session = llvm_ctx.get_typing_session();
        let lowered = session.llvm_type(&t).unwrap();
        assert_snapshot!("type_to_llvm", lowered, &description);
    }
}