hugr_llvm/custom/types.rs

1use std::marker::PhantomData;
2
3use itertools::Itertools as _;
4
5use hugr_core::types::CustomType;
6
7use anyhow::Result;
8use inkwell::types::{BasicMetadataTypeEnum, BasicType as _, BasicTypeEnum, FunctionType};
9
10pub use crate::utils::type_map::CustomTypeKey;
11
12use crate::{
13    sum::LLVMSumType,
14    types::{HugrFuncType, HugrSumType, TypingSession},
15    utils::type_map::TypeMapping,
16};
17
18pub trait LLVMCustomTypeFn<'a>:
19    for<'c> Fn(TypingSession<'c, 'a>, &CustomType) -> Result<BasicTypeEnum<'c>> + 'a
20{
21}
22
23impl<
24    'a,
25    F: for<'c> Fn(TypingSession<'c, 'a>, &CustomType) -> Result<BasicTypeEnum<'c>> + 'a + ?Sized,
26> LLVMCustomTypeFn<'a> for F
27{
28}
29
/// Stateless marker type selecting the LLVM lowering of HUGR types.
///
/// It carries no data. The lifetime `'a` ties the mapping to the lifetime
/// parameter of the [`TypingSession`]s it operates on, via [`PhantomData`].
#[derive(Debug, Default, Clone)]
pub struct LLVMTypeMapping<'a>(PhantomData<&'a ()>);
32
33impl<'a> TypeMapping for LLVMTypeMapping<'a> {
34    type InV<'c> = TypingSession<'c, 'a>;
35
36    type OutV<'c> = BasicTypeEnum<'c>;
37
38    type SumOutV<'c> = LLVMSumType<'c>;
39
40    type FuncOutV<'c> = FunctionType<'c>;
41
42    fn sum_into_out<'c>(&self, sum: Self::SumOutV<'c>) -> Self::OutV<'c> {
43        sum.as_basic_type_enum()
44    }
45
46    fn func_into_out<'c>(&self, sum: Self::FuncOutV<'c>) -> Self::OutV<'c> {
47        sum.ptr_type(Default::default()).as_basic_type_enum()
48    }
49
50    fn map_sum_type<'c>(
51        &self,
52        _sum_type: &HugrSumType,
53        context: TypingSession<'c, 'a>,
54        variants: impl IntoIterator<Item = Vec<Self::OutV<'c>>>,
55    ) -> Result<Self::SumOutV<'c>> {
56        LLVMSumType::try_new(context.iw_context(), variants.into_iter().collect_vec())
57    }
58
59    fn map_function_type<'c>(
60        &self,
61        _: &HugrFuncType,
62        context: TypingSession<'c, 'a>,
63        inputs: impl IntoIterator<Item = Self::OutV<'c>>,
64        outputs: impl IntoIterator<Item = Self::OutV<'c>>,
65    ) -> Result<Self::FuncOutV<'c>> {
66        let iw_context = context.iw_context();
67        let inputs: Vec<BasicMetadataTypeEnum<'c>> = inputs.into_iter().map_into().collect_vec();
68        let outputs = outputs.into_iter().collect_vec();
69        Ok(match outputs.as_slice() {
70            &[] => iw_context.void_type().fn_type(&inputs, false),
71            [res] => res.fn_type(&inputs, false),
72            ress => iw_context.struct_type(ress, false).fn_type(&inputs, false),
73        })
74    }
75}