hugr_llvm/utils/
type_map.rs

1//! Provides a generic mapping from [`HugrType`] into some domain.
2use std::collections::BTreeMap;
3
4use hugr_core::{
5    extension::ExtensionId,
6    types::{CustomType, TypeEnum, TypeName, TypeRow},
7};
8
9use anyhow::{Result, bail};
10
11use crate::types::{HugrFuncType, HugrSumType, HugrType};
12
13pub trait TypeMapFnHelper<'c, TM: TypeMapping>:
14    Fn(TM::InV<'c>, &CustomType) -> Result<TM::OutV<'c>>
15{
16}
17
18impl<'c, TM: TypeMapping, F> TypeMapFnHelper<'c, TM> for F where
19    F: Fn(TM::InV<'c>, &CustomType) -> Result<TM::OutV<'c>> + ?Sized
20{
21}
22
23/// A helper trait to name the type of the Callback used by
24/// [`TypeMap<TM>`](TypeMap).
25pub trait TypeMappingFn<'a, TM: TypeMapping>: 'a {
26    fn map_type<'c>(&self, inv: TM::InV<'c>, ty: &CustomType) -> Result<TM::OutV<'c>>;
27}
28
29impl<'a, TM: TypeMapping, F: for<'c> TypeMapFnHelper<'c, TM> + 'a> TypeMappingFn<'a, TM> for F {
30    fn map_type<'c>(&self, inv: TM::InV<'c>, ty: &CustomType) -> Result<TM::OutV<'c>> {
31        self(inv, ty)
32    }
33}
34
35/// Defines a mapping from [`HugrType`] to `OutV`;
36pub trait TypeMapping {
37    /// Auxiliary data provided when mapping from a [`HugrType`].
38    type InV<'c>: Clone;
39    /// The target type of the mapping.
40    type OutV<'c>;
41    /// The target type when mapping from [`HugrSumType`]s. This type must be
42    /// convertible to `OutV` via `sum_into_out`.
43    type SumOutV<'c>;
44    /// The target type when mapping from [`HugrFuncType`]s. This type must be
45    /// convertible to `OutV` via `func_into_out`.
46    type FuncOutV<'c>;
47
48    /// Returns the result of the mapping on `sum_type`, with auxiliary data
49    /// `inv`, and when the result of mapping all fields of all variants is
50    /// given by `variants`.
51    fn map_sum_type<'c>(
52        &self,
53        sum_type: &HugrSumType,
54        inv: Self::InV<'c>,
55        variants: impl IntoIterator<Item = Vec<Self::OutV<'c>>>,
56    ) -> Result<Self::SumOutV<'c>>;
57
58    /// Returns the result of the mapping on `function_type`, with auxiliary data
59    /// `inv`, and when the result of mapping all inputs is given by `inputs`
60    /// and the result of mapping all outputs is given by `outputs`.
61    fn map_function_type<'c>(
62        &self,
63        function_type: &HugrFuncType,
64        inv: Self::InV<'c>,
65        inputs: impl IntoIterator<Item = Self::OutV<'c>>,
66        outputs: impl IntoIterator<Item = Self::OutV<'c>>,
67    ) -> Result<Self::FuncOutV<'c>>;
68
69    /// Infallibly convert from the result of `map_sum_type` to the result of
70    /// the mapping.
71    fn sum_into_out<'c>(&self, sum: Self::SumOutV<'c>) -> Self::OutV<'c>;
72
73    /// Infallibly convert from the result of `map_functype` to the result of
74    /// the mapping.
75    fn func_into_out<'c>(&self, sum: Self::FuncOutV<'c>) -> Self::OutV<'c>;
76
77    /// Construct an appropriate result of the mapping when `hugr_type` is not a
78    /// function, sum, registered custom type, or composition of same.
79    fn default_out<'c>(
80        &self,
81        #[allow(unused)] inv: Self::InV<'c>,
82        hugr_type: &HugrType,
83    ) -> Result<Self::OutV<'c>> {
84        bail!("Unknown type: {hugr_type}")
85    }
86}
87
/// Key under which a custom-type callback is registered: the id of the
/// extension that defines the type, paired with the type's name.
pub type CustomTypeKey = (ExtensionId, TypeName);
89
90/// An impl of `TypeMapping` together with a collection of callbacks
91/// implementing the mapping.
92///
93/// Callbacks may hold references with lifetimes longer than `'a`
94#[derive(Default)]
95pub struct TypeMap<'a, TM: TypeMapping> {
96    type_map: TM,
97    custom_hooks: BTreeMap<CustomTypeKey, Box<dyn TypeMappingFn<'a, TM> + 'a>>,
98}
99
100impl<'a, TM: TypeMapping + 'a> TypeMap<'a, TM> {
101    /// Sets the callback for the given custom type.
102    ///
103    /// Returns false if this callback replaces another callback, which is
104    /// discarded, and true otherwise.
105    pub fn set_callback(
106        &mut self,
107        custom_type_key: CustomTypeKey,
108        hook: impl TypeMappingFn<'a, TM> + 'a,
109    ) -> bool {
110        self.custom_hooks
111            .insert(custom_type_key, Box::new(hook))
112            .is_none()
113    }
114
115    /// Map `hugr_type` using the [`TypeMapping`] `TM`, the registered callbacks,
116    /// and the auxiliary data `inv`.
117    pub fn map_type<'c>(&self, hugr_type: &HugrType, inv: TM::InV<'c>) -> Result<TM::OutV<'c>> {
118        match hugr_type.as_type_enum() {
119            TypeEnum::Extension(custom_type) => {
120                let key = (custom_type.extension().clone(), custom_type.name().clone());
121                let Some(handler) = self.custom_hooks.get(&key) else {
122                    return self.type_map.default_out(inv, &custom_type.clone().into());
123                };
124                handler.map_type(inv, custom_type)
125            }
126            TypeEnum::Sum(sum_type) => self
127                .map_sum_type(sum_type, inv)
128                .map(|x| self.type_map.sum_into_out(x)),
129            TypeEnum::Function(function_type) => self
130                .map_function_type(&function_type.as_ref().clone().try_into()?, inv)
131                .map(|x| self.type_map.func_into_out(x)),
132            _ => self.type_map.default_out(inv, hugr_type),
133        }
134    }
135
136    /// As `map_type`, but maps a [`HugrSumType`] to an [`TypeMapping::SumOutV`].
137    pub fn map_sum_type<'c>(
138        &self,
139        sum_type: &HugrSumType,
140        inv: TM::InV<'c>,
141    ) -> Result<TM::SumOutV<'c>> {
142        let inv2 = inv.clone();
143        self.type_map.map_sum_type(
144            sum_type,
145            inv,
146            (0..sum_type.num_variants())
147                .map(move |i| {
148                    let tr: TypeRow = sum_type.get_variant(i).unwrap().clone().try_into().unwrap();
149                    tr.iter()
150                        .map(|t| self.map_type(t, inv2.clone()))
151                        .collect::<Result<Vec<_>>>()
152                })
153                .collect::<Result<Vec<_>>>()?,
154        )
155    }
156
157    /// As `map_type`, but maps a [`HugrSumType`] to an [`TypeMapping::FuncOutV`].
158    pub fn map_function_type<'c>(
159        &self,
160        func_type: &HugrFuncType,
161        inv: TM::InV<'c>,
162    ) -> Result<TM::FuncOutV<'c>> {
163        let inputs = func_type
164            .input()
165            .iter()
166            .map(|t| self.map_type(t, inv.clone()))
167            .collect::<Result<Vec<_>>>()?;
168        let outputs = func_type
169            .output()
170            .iter()
171            .map(|t| self.map_type(t, inv.clone()))
172            .collect::<Result<Vec<_>>>()?;
173        self.type_map
174            .map_function_type(func_type, inv, inputs, outputs)
175    }
176}