hugr_llvm/custom/
types.rs1use std::marker::PhantomData;
2
3use itertools::Itertools as _;
4
5use hugr_core::types::CustomType;
6
7use anyhow::Result;
8use inkwell::types::{BasicMetadataTypeEnum, BasicType as _, BasicTypeEnum, FunctionType};
9
10pub use crate::utils::type_map::CustomTypeKey;
11
12use crate::{
13 sum::LLVMSumType,
14 types::{HugrFuncType, HugrSumType, TypingSession},
15 utils::type_map::TypeMapping,
16};
17
18pub trait LLVMCustomTypeFn<'a>:
19 for<'c> Fn(TypingSession<'c, 'a>, &CustomType) -> Result<BasicTypeEnum<'c>> + 'a
20{
21}
22
23impl<
24 'a,
25 F: for<'c> Fn(TypingSession<'c, 'a>, &CustomType) -> Result<BasicTypeEnum<'c>> + 'a + ?Sized,
26> LLVMCustomTypeFn<'a> for F
27{
28}
29
/// The [`TypeMapping`] used to lower HUGR types to inkwell (LLVM) types.
///
/// It carries no data at runtime. The `'a` lifetime is tracked only through
/// [`PhantomData`], so that it can appear in the associated types of its
/// `TypeMapping` implementation (e.g. `TypingSession<'c, 'a>`).
#[derive(Debug, Default, Clone)]
pub struct LLVMTypeMapping<'a>(PhantomData<&'a ()>);
32
33impl<'a> TypeMapping for LLVMTypeMapping<'a> {
34 type InV<'c> = TypingSession<'c, 'a>;
35
36 type OutV<'c> = BasicTypeEnum<'c>;
37
38 type SumOutV<'c> = LLVMSumType<'c>;
39
40 type FuncOutV<'c> = FunctionType<'c>;
41
42 fn sum_into_out<'c>(&self, sum: Self::SumOutV<'c>) -> Self::OutV<'c> {
43 sum.as_basic_type_enum()
44 }
45
46 fn func_into_out<'c>(&self, sum: Self::FuncOutV<'c>) -> Self::OutV<'c> {
47 sum.ptr_type(Default::default()).as_basic_type_enum()
48 }
49
50 fn map_sum_type<'c>(
51 &self,
52 _sum_type: &HugrSumType,
53 context: TypingSession<'c, 'a>,
54 variants: impl IntoIterator<Item = Vec<Self::OutV<'c>>>,
55 ) -> Result<Self::SumOutV<'c>> {
56 LLVMSumType::try_new(context.iw_context(), variants.into_iter().collect_vec())
57 }
58
59 fn map_function_type<'c>(
60 &self,
61 _: &HugrFuncType,
62 context: TypingSession<'c, 'a>,
63 inputs: impl IntoIterator<Item = Self::OutV<'c>>,
64 outputs: impl IntoIterator<Item = Self::OutV<'c>>,
65 ) -> Result<Self::FuncOutV<'c>> {
66 let iw_context = context.iw_context();
67 let inputs: Vec<BasicMetadataTypeEnum<'c>> = inputs.into_iter().map_into().collect_vec();
68 let outputs = outputs.into_iter().collect_vec();
69 Ok(match outputs.as_slice() {
70 &[] => iw_context.void_type().fn_type(&inputs, false),
71 [res] => res.fn_type(&inputs, false),
72 ress => iw_context.struct_type(ress, false).fn_type(&inputs, false),
73 })
74 }
75}