hugr_llvm/utils/
type_map.rs1use std::collections::BTreeMap;
3
4use hugr_core::{
5 extension::ExtensionId,
6 types::{CustomType, TypeEnum, TypeName, TypeRow},
7};
8
9use anyhow::{Result, bail};
10
11use crate::types::{HugrFuncType, HugrSumType, HugrType};
12
13pub trait TypeMapFnHelper<'c, TM: TypeMapping>:
14 Fn(TM::InV<'c>, &CustomType) -> Result<TM::OutV<'c>>
15{
16}
17
18impl<'c, TM: TypeMapping, F> TypeMapFnHelper<'c, TM> for F where
19 F: Fn(TM::InV<'c>, &CustomType) -> Result<TM::OutV<'c>> + ?Sized
20{
21}
22
23pub trait TypeMappingFn<'a, TM: TypeMapping>: 'a {
26 fn map_type<'c>(&self, inv: TM::InV<'c>, ty: &CustomType) -> Result<TM::OutV<'c>>;
27}
28
29impl<'a, TM: TypeMapping, F: for<'c> TypeMapFnHelper<'c, TM> + 'a> TypeMappingFn<'a, TM> for F {
30 fn map_type<'c>(&self, inv: TM::InV<'c>, ty: &CustomType) -> Result<TM::OutV<'c>> {
31 self(inv, ty)
32 }
33}
34
35pub trait TypeMapping {
37 type InV<'c>: Clone;
39 type OutV<'c>;
41 type SumOutV<'c>;
44 type FuncOutV<'c>;
47
48 fn map_sum_type<'c>(
52 &self,
53 sum_type: &HugrSumType,
54 inv: Self::InV<'c>,
55 variants: impl IntoIterator<Item = Vec<Self::OutV<'c>>>,
56 ) -> Result<Self::SumOutV<'c>>;
57
58 fn map_function_type<'c>(
62 &self,
63 function_type: &HugrFuncType,
64 inv: Self::InV<'c>,
65 inputs: impl IntoIterator<Item = Self::OutV<'c>>,
66 outputs: impl IntoIterator<Item = Self::OutV<'c>>,
67 ) -> Result<Self::FuncOutV<'c>>;
68
69 fn sum_into_out<'c>(&self, sum: Self::SumOutV<'c>) -> Self::OutV<'c>;
72
73 fn func_into_out<'c>(&self, sum: Self::FuncOutV<'c>) -> Self::OutV<'c>;
76
77 fn default_out<'c>(
80 &self,
81 #[allow(unused)] inv: Self::InV<'c>,
82 hugr_type: &HugrType,
83 ) -> Result<Self::OutV<'c>> {
84 bail!("Unknown type: {hugr_type}")
85 }
86}
87
88pub type CustomTypeKey = (ExtensionId, TypeName);
89
90#[derive(Default)]
95pub struct TypeMap<'a, TM: TypeMapping> {
96 type_map: TM,
97 custom_hooks: BTreeMap<CustomTypeKey, Box<dyn TypeMappingFn<'a, TM> + 'a>>,
98}
99
100impl<'a, TM: TypeMapping + 'a> TypeMap<'a, TM> {
101 pub fn set_callback(
106 &mut self,
107 custom_type_key: CustomTypeKey,
108 hook: impl TypeMappingFn<'a, TM> + 'a,
109 ) -> bool {
110 self.custom_hooks
111 .insert(custom_type_key, Box::new(hook))
112 .is_none()
113 }
114
115 pub fn map_type<'c>(&self, hugr_type: &HugrType, inv: TM::InV<'c>) -> Result<TM::OutV<'c>> {
118 match hugr_type.as_type_enum() {
119 TypeEnum::Extension(custom_type) => {
120 let key = (custom_type.extension().clone(), custom_type.name().clone());
121 let Some(handler) = self.custom_hooks.get(&key) else {
122 return self.type_map.default_out(inv, &custom_type.clone().into());
123 };
124 handler.map_type(inv, custom_type)
125 }
126 TypeEnum::Sum(sum_type) => self
127 .map_sum_type(sum_type, inv)
128 .map(|x| self.type_map.sum_into_out(x)),
129 TypeEnum::Function(function_type) => self
130 .map_function_type(&function_type.as_ref().clone().try_into()?, inv)
131 .map(|x| self.type_map.func_into_out(x)),
132 _ => self.type_map.default_out(inv, hugr_type),
133 }
134 }
135
136 pub fn map_sum_type<'c>(
138 &self,
139 sum_type: &HugrSumType,
140 inv: TM::InV<'c>,
141 ) -> Result<TM::SumOutV<'c>> {
142 let inv2 = inv.clone();
143 self.type_map.map_sum_type(
144 sum_type,
145 inv,
146 (0..sum_type.num_variants())
147 .map(move |i| {
148 let tr: TypeRow = sum_type.get_variant(i).unwrap().clone().try_into().unwrap();
149 tr.iter()
150 .map(|t| self.map_type(t, inv2.clone()))
151 .collect::<Result<Vec<_>>>()
152 })
153 .collect::<Result<Vec<_>>>()?,
154 )
155 }
156
157 pub fn map_function_type<'c>(
159 &self,
160 func_type: &HugrFuncType,
161 inv: TM::InV<'c>,
162 ) -> Result<TM::FuncOutV<'c>> {
163 let inputs = func_type
164 .input()
165 .iter()
166 .map(|t| self.map_type(t, inv.clone()))
167 .collect::<Result<Vec<_>>>()?;
168 let outputs = func_type
169 .output()
170 .iter()
171 .map(|t| self.map_type(t, inv.clone()))
172 .collect::<Result<Vec<_>>>()?;
173 self.type_map
174 .map_function_type(func_type, inv, inputs, outputs)
175 }
176}