1use std::rc::Rc;
2
3use anyhow::Result;
4use delegate::delegate;
5use hugr_core::extension::ExtensionId;
6use hugr_core::types::{SumType, Type, TypeName};
7use inkwell::types::FunctionType;
8use inkwell::{context::Context, types::BasicTypeEnum};
9
10use crate::custom::types::{LLVMCustomTypeFn, LLVMTypeMapping};
11pub use crate::sum::LLVMSumType;
12use crate::utils::type_map::TypeMap;
13
/// The HUGR representation of a function signature that this module lowers
/// to an LLVM function type; an alias for [`hugr_core::types::Signature`].
pub type HugrFuncType = hugr_core::types::Signature;

/// The HUGR representation of a value type that this module lowers to an
/// LLVM basic type; an alias for [`hugr_core::types::Type`].
pub type HugrType = Type;

/// The HUGR representation of a sum type that this module lowers to an
/// [`LLVMSumType`]; an alias for [`hugr_core::types::SumType`].
pub type HugrSumType = SumType;
23
/// Pairs an inkwell LLVM [`Context`] with a shared [`TypeConverter`], so that
/// HUGR types can be converted into LLVM types belonging to that context.
///
/// Cloning a session copies the context reference and bumps the `Rc`
/// reference count of the converter; the converter itself is not duplicated.
#[derive(Clone)]
pub struct TypingSession<'c, 'a> {
    /// The LLVM context in which converted types are created (the `'c`
    /// lifetime of the produced types is tied to it).
    iw_context: &'c Context,
    /// The shared converter that performs the actual HUGR -> LLVM mapping.
    type_converter: Rc<TypeConverter<'a>>,
}
31
32impl<'c, 'a> TypingSession<'c, 'a> {
33 delegate! {
34 to self.type_converter.clone() {
35 pub fn llvm_type(&self, [self.clone()], hugr_type: &HugrType) -> Result<BasicTypeEnum<'c>>;
37 pub fn llvm_func_type(&self, [self.clone()], hugr_type: &HugrFuncType) -> Result<FunctionType<'c>>;
39 pub fn llvm_sum_type(&self, [self.clone()], hugr_type: HugrSumType) -> Result<LLVMSumType<'c>>;
41 }
42 }
43
44 pub fn new(iw_context: &'c Context, type_converter: Rc<TypeConverter<'a>>) -> Self {
46 Self {
47 iw_context,
48 type_converter,
49 }
50 }
51
52 pub fn iw_context(&self) -> &'c Context {
54 self.iw_context
55 }
56}
57
/// Converts HUGR types into LLVM types.
///
/// A newtype over a [`TypeMap`] of [`LLVMTypeMapping`]s. Custom (extension)
/// types are handled by callbacks registered per `(ExtensionId, TypeName)`
/// key via `custom_type`. The `Default` value starts with no custom-type
/// callbacks registered by this type itself.
#[derive(Default)]
pub struct TypeConverter<'a>(TypeMap<'a, LLVMTypeMapping<'a>>);
60
61impl<'a> TypeConverter<'a> {
62 pub(super) fn custom_type(
63 &mut self,
64 custom_type: (ExtensionId, TypeName),
65 handler: impl LLVMCustomTypeFn<'a>,
66 ) {
67 self.0.set_callback(custom_type, handler);
68 }
69
70 pub fn llvm_type<'c>(
71 self: Rc<Self>,
72 context: TypingSession<'c, 'a>,
73 hugr_type: &HugrType,
74 ) -> Result<BasicTypeEnum<'c>> {
75 self.0.map_type(hugr_type, context)
76 }
77
78 pub fn llvm_func_type<'c>(
79 self: Rc<Self>,
80 context: TypingSession<'c, 'a>,
81 hugr_type: &HugrFuncType,
82 ) -> Result<FunctionType<'c>> {
83 self.0.map_function_type(hugr_type, context)
84 }
85
86 pub fn llvm_sum_type<'c>(
87 self: Rc<Self>,
88 context: TypingSession<'c, 'a>,
89 hugr_type: HugrSumType,
90 ) -> Result<LLVMSumType<'c>> {
91 self.0.map_sum_type(&hugr_type, context)
92 }
93
94 pub fn session<'c>(self: Rc<Self>, iw_context: &'c Context) -> TypingSession<'c, 'a> {
95 TypingSession::new(iw_context, self)
96 }
97}
98
/// Snapshot tests for converting HUGR types into LLVM types.
#[cfg(test)]
// NOTE(review): the reason for allowing `drop_bounds` is not visible in this
// module — presumably the lint fires on fixture/macro-generated code; confirm
// and document it.
#[allow(drop_bounds)]
pub mod test {

    use hugr_core::{
        std_extensions::arithmetic::int_types::INT_TYPES,
        type_row,
        types::{SumType, Type},
    };

    use insta::assert_snapshot;
    use rstest::rstest;

    use crate::{extension::int::add_int_extensions, test::*, types::HugrFuncType};

    // Each case's `_id` is forwarded to the `llvm_ctx` fixture via
    // `#[with(_id)]`. The converted LLVM function type is snapshotted, and
    // `ft.to_string()` is passed as insta's third argument, presumably so the
    // snapshot records which HUGR signature produced it.
    #[rstest]
    #[case(0,HugrFuncType::new(type_row!(Type::new_unit_sum(2)), type_row!()))]
    #[case(1, HugrFuncType::new(Type::new_unit_sum(1), Type::new_unit_sum(3)))]
    #[case(2,HugrFuncType::new(vec![], vec![Type::new_unit_sum(1), Type::new_unit_sum(1)]))]
    fn func_types(#[case] _id: i32, #[with(_id)] llvm_ctx: TestContext, #[case] ft: HugrFuncType) {
        assert_snapshot!(
            "func_type_to_llvm",
            llvm_ctx.get_typing_session().llvm_func_type(&ft).unwrap(),
            &ft.to_string()
        )
    }

    // Covers unary sums of 0, 1 and 2 variants, plus a two-variant sum whose
    // variants each carry two fields.
    #[rstest]
    #[case(0, SumType::new_unary(0))]
    #[case(1, SumType::new_unary(1))]
    #[case(2,SumType::new([vec![Type::new_unit_sum(4), Type::new_unit_sum(1)], vec![Type::new_unit_sum(2), Type::new_unit_sum(3)]]))]
    #[case(3, SumType::new_unary(2))]
    fn sum_types(#[case] _id: i32, #[with(_id)] llvm_ctx: TestContext, #[case] st: SumType) {
        assert_snapshot!(
            "sum_type_to_llvm",
            llvm_ctx
                .get_typing_session()
                // `llvm_sum_type` takes its argument by value, so clone to keep
                // `st` available for the snapshot description below.
                .llvm_sum_type(st.clone())
                .unwrap(),
            &st.to_string()
        )
    }

    // Cases 0-6 use `INT_TYPES`, so the integer extension lowerings are
    // registered on the context before any conversion. Case 7 converts a
    // function-valued type.
    #[rstest]
    #[case(0, INT_TYPES[0].clone())]
    #[case(1, INT_TYPES[3].clone())]
    #[case(2, INT_TYPES[4].clone())]
    #[case(3, INT_TYPES[5].clone())]
    #[case(4, INT_TYPES[6].clone())]
    #[case(5, Type::new_sum([vec![INT_TYPES[2].clone()]]))]
    #[case(6, Type::new_sum([vec![INT_TYPES[6].clone(),Type::new_unit_sum(1)], vec![Type::new_unit_sum(2), INT_TYPES[2].clone()]]))]
    #[case(7, Type::new_function(HugrFuncType::new(type_row!(Type::new_unit_sum(2)), Type::new_unit_sum(3))))]
    fn ext_types(#[case] _id: i32, #[with(_id)] mut llvm_ctx: TestContext, #[case] t: Type) {
        llvm_ctx.add_extensions(add_int_extensions);
        assert_snapshot!(
            "type_to_llvm",
            llvm_ctx.get_typing_session().llvm_type(&t).unwrap(),
            &t.to_string()
        );
    }
}