hugr_llvm/
types.rs

1use std::rc::Rc;
2
3use anyhow::Result;
4use delegate::delegate;
5use hugr_core::extension::ExtensionId;
6use hugr_core::types::{SumType, Type, TypeName};
7use inkwell::types::FunctionType;
8use inkwell::{context::Context, types::BasicTypeEnum};
9
10use crate::custom::types::{LLVMCustomTypeFn, LLVMTypeMapping};
11pub use crate::sum::LLVMSumType;
12use crate::utils::type_map::TypeMap;
13
/// A type alias for a hugr function type. We use this to disambiguate from
/// the LLVM [`FunctionType`].
pub type HugrFuncType = hugr_core::types::Signature;

/// A type alias for a hugr type. We use this to disambiguate from LLVM types
/// such as [`BasicTypeEnum`].
pub type HugrType = Type;

/// A type alias for a hugr sum type, distinct from the LLVM-side
/// [`LLVMSumType`] re-exported by this module.
pub type HugrSumType = SumType;
23
24/// A type that holds [Rc] shared pointers to everything needed to convert from
25/// a hugr [`HugrType`] to an LLVM [Type](inkwell::types).
26#[derive(Clone)]
27pub struct TypingSession<'c, 'a> {
28    iw_context: &'c Context,
29    type_converter: Rc<TypeConverter<'a>>,
30}
31
32impl<'c, 'a> TypingSession<'c, 'a> {
33    delegate! {
34        to self.type_converter.clone() {
35            /// Convert a [HugrType] into an LLVM [Type](BasicTypeEnum).
36            pub fn llvm_type(&self, [self.clone()], hugr_type: &HugrType) -> Result<BasicTypeEnum<'c>>;
37            /// Convert a [HugrFuncType] into an LLVM [FunctionType].
38            pub fn llvm_func_type(&self, [self.clone()], hugr_type: &HugrFuncType) -> Result<FunctionType<'c>>;
39            /// Convert a hugr [HugrSumType] into an LLVM [LLVMSumType].
40            pub fn llvm_sum_type(&self, [self.clone()], hugr_type: HugrSumType) -> Result<LLVMSumType<'c>>;
41        }
42    }
43
44    /// Creates a new `TypingSession`.
45    #[must_use]
46    pub fn new(iw_context: &'c Context, type_converter: Rc<TypeConverter<'a>>) -> Self {
47        Self {
48            iw_context,
49            type_converter,
50        }
51    }
52
53    /// Returns a reference to the inner [Context].
54    #[must_use]
55    pub fn iw_context(&self) -> &'c Context {
56        self.iw_context
57    }
58}
59
60#[derive(Default)]
61pub struct TypeConverter<'a>(TypeMap<'a, LLVMTypeMapping<'a>>);
62
63impl<'a> TypeConverter<'a> {
64    pub(super) fn custom_type(
65        &mut self,
66        custom_type: (ExtensionId, TypeName),
67        handler: impl LLVMCustomTypeFn<'a>,
68    ) {
69        self.0.set_callback(custom_type, handler);
70    }
71
72    pub fn llvm_type<'c>(
73        self: Rc<Self>,
74        context: TypingSession<'c, 'a>,
75        hugr_type: &HugrType,
76    ) -> Result<BasicTypeEnum<'c>> {
77        self.0.map_type(hugr_type, context)
78    }
79
80    pub fn llvm_func_type<'c>(
81        self: Rc<Self>,
82        context: TypingSession<'c, 'a>,
83        hugr_type: &HugrFuncType,
84    ) -> Result<FunctionType<'c>> {
85        self.0.map_function_type(hugr_type, context)
86    }
87
88    pub fn llvm_sum_type<'c>(
89        self: Rc<Self>,
90        context: TypingSession<'c, 'a>,
91        hugr_type: HugrSumType,
92    ) -> Result<LLVMSumType<'c>> {
93        self.0.map_sum_type(&hugr_type, context)
94    }
95
96    #[must_use]
97    pub fn session<'c>(self: Rc<Self>, iw_context: &'c Context) -> TypingSession<'c, 'a> {
98        TypingSession::new(iw_context, self)
99    }
100}
101
/// Snapshot tests for lowering hugr function, sum and extension types to
/// LLVM types through a [`TypingSession`](crate::types::TypingSession).
#[cfg(test)]
// NOTE(review): the reason for allowing `drop_bounds` is not visible in this
// module — presumably it is triggered by fixture code; confirm before removing.
#[allow(drop_bounds)]
pub mod test {

    use hugr_core::{
        std_extensions::arithmetic::int_types::INT_TYPES,
        type_row,
        types::{SumType, Type},
    };

    use insta::assert_snapshot;
    use rstest::rstest;

    use crate::{test::*, types::HugrFuncType};

    /// Snapshots the LLVM function type produced for each hugr signature.
    #[rstest]
    #[case(0, HugrFuncType::new(type_row!(Type::new_unit_sum(2)), type_row!()))]
    #[case(1, HugrFuncType::new(Type::new_unit_sum(1), Type::new_unit_sum(3)))]
    #[case(2, HugrFuncType::new(vec![], vec![Type::new_unit_sum(1), Type::new_unit_sum(1)]))]
    fn func_types(#[case] _id: i32, #[with(_id)] llvm_ctx: TestContext, #[case] ft: HugrFuncType) {
        let session = llvm_ctx.get_typing_session();
        let lowered = session.llvm_func_type(&ft).unwrap();
        // The third argument is the description recorded with the snapshot.
        assert_snapshot!("func_type_to_llvm", lowered, &ft.to_string());
    }

    /// Snapshots the LLVM sum type produced for each hugr sum type.
    #[rstest]
    #[case(0, SumType::new_unary(0))]
    #[case(1, SumType::new_unary(1))]
    #[case(2, SumType::new([vec![Type::new_unit_sum(4), Type::new_unit_sum(1)], vec![Type::new_unit_sum(2), Type::new_unit_sum(3)]]))]
    #[case(3, SumType::new_unary(2))]
    fn sum_types(#[case] _id: i32, #[with(_id)] llvm_ctx: TestContext, #[case] st: SumType) {
        let session = llvm_ctx.get_typing_session();
        // `llvm_sum_type` takes the sum type by value; `st` is still needed
        // afterwards for the snapshot description, hence the clone.
        let lowered = session.llvm_sum_type(st.clone()).unwrap();
        assert_snapshot!("sum_type_to_llvm", lowered, &st.to_string());
    }

    /// Snapshots the LLVM type produced for integer extension types, sums
    /// containing them, and a function type, after registering the default
    /// int extensions on the test context.
    #[rstest]
    #[case(0, INT_TYPES[0].clone())]
    #[case(1, INT_TYPES[3].clone())]
    #[case(2, INT_TYPES[4].clone())]
    #[case(3, INT_TYPES[5].clone())]
    #[case(4, INT_TYPES[6].clone())]
    #[case(5, Type::new_sum([vec![INT_TYPES[2].clone()]]))]
    #[case(6, Type::new_sum([vec![INT_TYPES[6].clone(), Type::new_unit_sum(1)], vec![Type::new_unit_sum(2), INT_TYPES[2].clone()]]))]
    #[case(7, Type::new_function(HugrFuncType::new(type_row!(Type::new_unit_sum(2)), Type::new_unit_sum(3))))]
    fn ext_types(#[case] _id: i32, #[with(_id)] mut llvm_ctx: TestContext, #[case] t: Type) {
        use crate::CodegenExtsBuilder;

        llvm_ctx.add_extensions(CodegenExtsBuilder::add_default_int_extensions);

        let session = llvm_ctx.get_typing_session();
        let lowered = session.llvm_type(&t).unwrap();
        assert_snapshot!("type_to_llvm", lowered, &t.to_string());
    }
}