1use std::rc::Rc;
2
3use anyhow::Result;
4use delegate::delegate;
5use hugr_core::extension::ExtensionId;
6use hugr_core::types::{SumType, Type, TypeName};
7use inkwell::types::FunctionType;
8use inkwell::{context::Context, types::BasicTypeEnum};
9
10use crate::custom::types::{LLVMCustomTypeFn, LLVMTypeMapping};
11pub use crate::sum::LLVMSumType;
12use crate::utils::type_map::TypeMap;
13
14pub type HugrFuncType = hugr_core::types::Signature;
17
18pub type HugrType = Type;
20
21pub type HugrSumType = SumType;
23
24#[derive(Clone)]
27pub struct TypingSession<'c, 'a> {
28 iw_context: &'c Context,
29 type_converter: Rc<TypeConverter<'a>>,
30}
31
32impl<'c, 'a> TypingSession<'c, 'a> {
33 delegate! {
34 to self.type_converter.clone() {
35 pub fn llvm_type(&self, [self.clone()], hugr_type: &HugrType) -> Result<BasicTypeEnum<'c>>;
37 pub fn llvm_func_type(&self, [self.clone()], hugr_type: &HugrFuncType) -> Result<FunctionType<'c>>;
39 pub fn llvm_sum_type(&self, [self.clone()], hugr_type: HugrSumType) -> Result<LLVMSumType<'c>>;
41 }
42 }
43
44 #[must_use]
46 pub fn new(iw_context: &'c Context, type_converter: Rc<TypeConverter<'a>>) -> Self {
47 Self {
48 iw_context,
49 type_converter,
50 }
51 }
52
53 #[must_use]
55 pub fn iw_context(&self) -> &'c Context {
56 self.iw_context
57 }
58}
59
60#[derive(Default)]
61pub struct TypeConverter<'a>(TypeMap<'a, LLVMTypeMapping<'a>>);
62
63impl<'a> TypeConverter<'a> {
64 pub(super) fn custom_type(
65 &mut self,
66 custom_type: (ExtensionId, TypeName),
67 handler: impl LLVMCustomTypeFn<'a>,
68 ) {
69 self.0.set_callback(custom_type, handler);
70 }
71
72 pub fn llvm_type<'c>(
73 self: Rc<Self>,
74 context: TypingSession<'c, 'a>,
75 hugr_type: &HugrType,
76 ) -> Result<BasicTypeEnum<'c>> {
77 self.0.map_type(hugr_type, context)
78 }
79
80 pub fn llvm_func_type<'c>(
81 self: Rc<Self>,
82 context: TypingSession<'c, 'a>,
83 hugr_type: &HugrFuncType,
84 ) -> Result<FunctionType<'c>> {
85 self.0.map_function_type(hugr_type, context)
86 }
87
88 pub fn llvm_sum_type<'c>(
89 self: Rc<Self>,
90 context: TypingSession<'c, 'a>,
91 hugr_type: HugrSumType,
92 ) -> Result<LLVMSumType<'c>> {
93 self.0.map_sum_type(&hugr_type, context)
94 }
95
96 #[must_use]
97 pub fn session<'c>(self: Rc<Self>, iw_context: &'c Context) -> TypingSession<'c, 'a> {
98 TypingSession::new(iw_context, self)
99 }
100}
101
// Snapshot tests for lowering HUGR types to LLVM types through a
// `TypingSession` obtained from the test `TestContext` fixture.
#[cfg(test)]
// NOTE(review): no `Drop` bounds appear in this module, so the reason for
// allowing `drop_bounds` is not visible here; presumably it silences a lint in
// macro/fixture-generated code — confirm and document the reason.
#[allow(drop_bounds)]
pub mod test {

    use hugr_core::{
        std_extensions::arithmetic::int_types::INT_TYPES,
        type_row,
        types::{SumType, Type},
    };

    use insta::assert_snapshot;
    use rstest::rstest;

    use crate::{test::*, types::HugrFuncType};

    // Lowers function signatures with `llvm_func_type` and snapshots the
    // resulting LLVM function type.
    // The leading case integer is forwarded to the `llvm_ctx` fixture via
    // `#[with(_id)]`; presumably it keeps each case's snapshot distinct —
    // TODO confirm against the fixture definition.
    #[rstest]
    #[case(0,HugrFuncType::new(type_row!(Type::new_unit_sum(2)), type_row!()))]
    #[case(1, HugrFuncType::new(Type::new_unit_sum(1), Type::new_unit_sum(3)))]
    #[case(2,HugrFuncType::new(vec![], vec![Type::new_unit_sum(1), Type::new_unit_sum(1)]))]
    fn func_types(#[case] _id: i32, #[with(_id)] llvm_ctx: TestContext, #[case] ft: HugrFuncType) {
        // The third argument records the signature's string form with the snapshot.
        assert_snapshot!(
            "func_type_to_llvm",
            llvm_ctx.get_typing_session().llvm_func_type(&ft).unwrap(),
            &ft.to_string()
        );
    }

    // Lowers sum types with `llvm_sum_type` (which takes the sum type by
    // value, hence the clone) and snapshots the resulting `LLVMSumType`.
    #[rstest]
    #[case(0, SumType::new_unary(0))]
    #[case(1, SumType::new_unary(1))]
    #[case(2,SumType::new([vec![Type::new_unit_sum(4), Type::new_unit_sum(1)], vec![Type::new_unit_sum(2), Type::new_unit_sum(3)]]))]
    #[case(3, SumType::new_unary(2))]
    fn sum_types(#[case] _id: i32, #[with(_id)] llvm_ctx: TestContext, #[case] st: SumType) {
        assert_snapshot!(
            "sum_type_to_llvm",
            llvm_ctx
                .get_typing_session()
                .llvm_sum_type(st.clone())
                .unwrap(),
            &st.to_string()
        );
    }

    // Lowers general types with `llvm_type`: integer types from `INT_TYPES`,
    // sums containing them, and a function type.
    #[rstest]
    #[case(0, INT_TYPES[0].clone())]
    #[case(1, INT_TYPES[3].clone())]
    #[case(2, INT_TYPES[4].clone())]
    #[case(3, INT_TYPES[5].clone())]
    #[case(4, INT_TYPES[6].clone())]
    #[case(5, Type::new_sum([vec![INT_TYPES[2].clone()]]))]
    #[case(6, Type::new_sum([vec![INT_TYPES[6].clone(),Type::new_unit_sum(1)], vec![Type::new_unit_sum(2), INT_TYPES[2].clone()]]))]
    #[case(7, Type::new_function(HugrFuncType::new(type_row!(Type::new_unit_sum(2)), Type::new_unit_sum(3))))]
    fn ext_types(#[case] _id: i32, #[with(_id)] mut llvm_ctx: TestContext, #[case] t: Type) {
        use crate::CodegenExtsBuilder;

        // Register the default int codegen extensions before lowering, since
        // most cases involve `INT_TYPES`.
        llvm_ctx.add_extensions(CodegenExtsBuilder::add_default_int_extensions);
        assert_snapshot!(
            "type_to_llvm",
            llvm_ctx.get_typing_session().llvm_type(&t).unwrap(),
            &t.to_string()
        );
    }
}