hugr_llvm/custom/
load_constant.rs

1//! Provides the implementation for a collection of [CustomConst] callbacks.
2use std::{any::TypeId, collections::HashMap};
3
4use hugr_core::{ops::constant::CustomConst, HugrView, Node};
5use inkwell::values::BasicValueEnum;
6
7use anyhow::{anyhow, bail, ensure, Result};
8
9use crate::emit::EmitFuncContext;
10
11/// A helper trait for describing the callback used for emitting [CustomConst]s,
12/// and for hanging documentation. We have the appropriate `Fn` as a supertrait,
13/// and there is a blanket impl for that `Fn`. We do not intend users to impl
14/// this trait.
15///
16/// `LoadConstantFn` callbacks for `CC`, which must impl [CustomConst], should
17/// materialise an appropriate [BasicValueEnum]. The type of this value must
18/// match the result of [EmitFuncContext::llvm_type] on [CustomConst::get_type].
19///
20/// Callbacks may hold references with lifetimes older than `'a`.
21pub trait LoadConstantFn<'a, H: ?Sized, CC: CustomConst + ?Sized>:
22    for<'c> Fn(&mut EmitFuncContext<'c, 'a, H>, &CC) -> Result<BasicValueEnum<'c>> + 'a
23{
24}
25
26impl<
27        'a,
28        H: ?Sized,
29        CC: ?Sized + CustomConst,
30        F: 'a
31            + ?Sized
32            + for<'c> Fn(&mut EmitFuncContext<'c, 'a, H>, &CC) -> Result<BasicValueEnum<'c>>,
33    > LoadConstantFn<'a, H, CC> for F
34{
35}
36
37/// A collection of [LoadConstantFn] callbacks registered for various impls of [CustomConst].
38/// The callbacks are keyed by the [TypeId] of those impls.
39#[derive(Default)]
40pub struct LoadConstantsMap<'a, H>(
41    HashMap<TypeId, Box<dyn LoadConstantFn<'a, H, dyn CustomConst>>>,
42);
43
44impl<'a, H: HugrView<Node = Node>> LoadConstantsMap<'a, H> {
45    /// Register a callback to emit a `CC` value.
46    ///
47    /// If a callback is already registered for that type, we will replace it.
48    pub fn custom_const<CC: CustomConst>(&mut self, handler: impl LoadConstantFn<'a, H, CC>) {
49        self.0.insert(
50            TypeId::of::<CC>(),
51            Box::new(move |context, konst: &dyn CustomConst| {
52                let cc = konst.downcast_ref::<CC>().ok_or(anyhow!(
53                    "impossible! Failed to downcast in LoadConstantsMap::custom_const"
54                ))?;
55                handler(context, cc)
56            }),
57        );
58    }
59
60    /// Emit instructions to materialise `konst` by delegating to the
61    /// appropriate inner callbacks.
62    pub fn emit_load_constant<'c>(
63        &self,
64        context: &mut EmitFuncContext<'c, 'a, H>,
65        konst: &dyn CustomConst,
66    ) -> Result<BasicValueEnum<'c>> {
67        let type_id = konst.type_id();
68        let Some(handler) = self.0.get(&type_id) else {
69            bail!(
70                "No extension could load constant name: {} type_id: {type_id:?}",
71                konst.name()
72            )
73        };
74        let r = handler(context, konst)?;
75        let r_type = r.get_type();
76        let konst_type = context.llvm_type(&konst.get_type())?;
77        ensure!(r_type == konst_type, "CustomConst handler returned a value of the wrong type. Expected: {konst_type} Actual: {r_type}");
78        Ok(r)
79    }
80}